From 7a2ff7ad1d0d0da5d325105432bda55d2b27938e Mon Sep 17 00:00:00 2001 From: traxys Date: Thu, 27 Jul 2023 00:06:36 +0200 Subject: [PATCH] Allow to login with OIDC --- Cargo.lock | 288 +++++++++++++++++++++- Cargo.toml | 3 + app/Cargo.toml | 1 + app/src/main.rs | 71 +++++- app/src/recipe/creator.rs | 16 +- app/src/recipe/view.rs | 4 +- migration/src/lib.rs | 2 + migration/src/m20220101_000001_account.rs | 1 + migration/src/m20230726_203858_oidc.rs | 40 +++ src/entity/user.rs | 5 +- src/main.rs | 258 ++++++++++++++++++- src/routes/mod.rs | 122 ++++++++- 12 files changed, 782 insertions(+), 29 deletions(-) create mode 100644 migration/src/m20230726_203858_oidc.rs diff --git a/Cargo.lock b/Cargo.lock index 30e8626..42f7f24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -739,6 +739,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", + "strsim", "syn 2.0.16", ] @@ -1000,6 +1001,7 @@ dependencies = [ "log", "pulldown-cmark", "serde", + "urlencoding", "uuid", "wasm-bindgen", ] @@ -1030,6 +1032,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "dyn-clone" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "304e6508efa593091e97a9abbc10f90aa7ca635b6d2784feff3c89d41dd12272" + [[package]] name = "ecdsa" version = "0.16.7" @@ -1081,6 +1089,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encoding_rs" +version = "0.8.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +dependencies = [ + "cfg-if", +] + [[package]] name = "enumset" version = "1.1.2" @@ -1410,8 +1427,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1509,6 +1528,25 @@ dependencies = [ "subtle", ] +[[package]] +name = "h2" +version = "0.3.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1715,6 +1753,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", + "h2", "http", "http-body", "httparse", @@ -1728,6 +1767,20 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97" +dependencies = [ + "futures-util", + "http", + "hyper", + "rustls 0.21.5", + "tokio", + "tokio-rustls 0.24.1", +] + [[package]] name = "iana-time-zone" version = "0.1.56" @@ -1792,6 +1845,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", + "serde", ] [[package]] @@ -1870,6 +1924,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ipnet" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" + [[package]] name = "itertools" version = "0.10.5" @@ -1921,7 +1981,7 @@ dependencies = [ "p256", "p384", "rand", - "rsa", + "rsa 0.7.2", "serde", "serde_json", "spki 0.6.0", @@ -2275,12 +2335,63 @@ dependencies = [ "libc", ] +[[package]] +name = "oauth2" +version = "4.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09a6e2a2b13a56ebeabba9142f911745be6456163fd6c3d361274ebcd891a80c" +dependencies = [ + "base64 0.13.1", + "chrono", + "getrandom", + "http", + "rand", + "reqwest", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror", + "url", +] + [[package]] name = "once_cell" version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +[[package]] +name = "openidconnect" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03335ade401352b354b017e7597ddb40040091da445b031bf659e597e032b1fc" +dependencies = [ + "base64 0.13.1", + "chrono", + "dyn-clone", + "hmac", + "http", + "itertools 0.10.5", + "log", + "oauth2", + "p256", + "p384", + "rand", + "rsa 0.9.2", + "serde", + "serde-value", + "serde_derive", + "serde_json", + "serde_path_to_error", + "serde_plain", + "serde_with", + "sha2", + "subtle", + "thiserror", + "url", +] + [[package]] name = "ordered-float" version = "2.10.0" @@ -2515,6 +2626,17 @@ dependencies = [ "zeroize", ] +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der 0.7.6", + "pkcs8 0.10.2", + "spki 0.7.2", +] + [[package]] name = "pkcs8" version = "0.9.0" @@ -2737,6 +2859,8 @@ dependencies = [ "envious", "jwt-simple", "migration", + "openidconnect", + "parking_lot 0.12.1", "sea-orm", "sea-query", "serde", @@ -2746,6 +2870,7 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", + "urlencoding", "uuid", ] @@ -2790,6 +2915,45 @@ dependencies = [ "bytecheck", ] +[[package]] +name = "reqwest" +version = "0.11.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +dependencies = [ + "base64 0.21.0", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-rustls", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls 0.21.5", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", + "tokio-rustls 0.24.1", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", + "winreg", +] + [[package]] name = "rfc6979" version = "0.4.0" @@ -2855,7 +3019,7 @@ dependencies = [ "num-integer", "num-iter", "num-traits", - "pkcs1", + "pkcs1 0.4.1", "pkcs8 0.9.0", "rand_core", "signature 1.6.4", @@ -2864,6 +3028,28 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rsa" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab43bb47d23c1a631b4b680199a45255dce26fa9ab2fa902581f624ff13e6a8" +dependencies = [ + "byteorder", + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-iter", + "num-traits", + "pkcs1 0.7.5", + "pkcs8 0.10.2", + "rand_core", + "signature 2.1.0", + "spki 0.7.2", + "subtle", + "zeroize", +] + [[package]] name = "rust_decimal" version = "1.29.1" @@ -2923,6 +3109,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "rustls" +version = "0.21.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79ea77c539259495ce8ca47f53e66ae0330a8819f67e23ac96ca02f50e7b7d36" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + [[package]] name = "rustls-pemfile" version = "1.0.2" @@ -2932,6 +3130,16 @@ dependencies = [ "base64 0.21.0", ] +[[package]] +name = "rustls-webpki" +version = "0.101.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15f36a6828982f422756984e47912a7a51dcbc2a197aa791158f8ca61cd8204e" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.12" @@ -3219,6 +3427,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_plain" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6018081315db179d0ce57b1fe4b62a12a0028c9cf9bbef868c9cf477b3c34ae" +dependencies = [ + "serde", +] + [[package]] name = "serde_repr" version = "0.1.12" @@ -3242,6 +3459,34 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21e47d95bc83ed33b2ecf84f4187ad1ab9685d18ff28db000c99deac8ce180e3" +dependencies = [ + "base64 0.21.0", + "chrono", + "hex", + "indexmap", + "serde", + "serde_json", + "serde_with_macros", + "time 0.3.21", +] + +[[package]] +name = "serde_with_macros" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea3cee93715c2e266b9338b7544da68a9f24e227722ba482bd1c024367c77c65" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.16", +] + [[package]] name = "sha1" version = "0.10.5" @@ -3486,7 +3731,7 @@ dependencies = [ "percent-encoding", "rand", "rust_decimal", - "rustls", + "rustls 0.20.8", "rustls-pemfile", "serde", "serde_json", @@ -3532,7 +3777,7 @@ checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" dependencies = [ "once_cell", "tokio", - "tokio-rustls", + "tokio-rustls 0.23.4", ] [[package]] @@ -3571,6 +3816,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.4.1" @@ -3753,11 +4004,21 @@ version = "0.23.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" dependencies = [ - "rustls", + "rustls 0.20.8", "tokio", "webpki", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.5", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.14" @@ -3780,6 +4041,7 @@ dependencies = [ "futures-sink", "pin-project-lite", "tokio", + "tracing", ] [[package]] @@ -3982,8 +4244,15 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf-8" version = "0.7.6" @@ -4360,6 +4629,15 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "winreg" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +dependencies = [ + "winapi", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index bd70d7c..43c694a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,10 @@ tower-http = { version = "0.4.0", features = ["cors", "fs"] } sha2 = "0.10" uuid = { version = "1.3", features = ["v4"] } sea-query = "0.28" +openidconnect = "3.3.0" envious = "0.2.2" +parking_lot = "0.12.1" +urlencoding = "2.1.3" [dependencies.sea-orm] version = "0.11" diff --git a/app/Cargo.toml b/app/Cargo.toml index a89d61b..aae7e8a 100644 --- a/app/Cargo.toml +++ b/app/Cargo.toml @@ -21,5 +21,6 @@ itertools = "0.11.0" log = "0.4.19" pulldown-cmark = "0.9.3" serde = { version = "1.0.164", features = ["derive"] } +urlencoding = "2.1.3" uuid = "1.4.0" wasm-bindgen = "0.2.87" diff --git a/app/src/main.rs b/app/src/main.rs index b9d9837..5cc25c6 100644 --- a/app/src/main.rs +++ b/app/src/main.rs @@ -6,7 +6,7 @@ use api::{ LoginResponse, UserInfo, }; use dioxus::prelude::*; -use dioxus_router::{use_router, Route, Router}; +use dioxus_router::{use_route, use_router, Redirect, Route, Router}; use gloo_storage::{errors::StorageError, LocalStorage, Storage}; use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -33,6 +33,11 @@ const API_ROUTE: &str = match option_env!("REGALADE_API_SERVER_BASE") { Some(v) => v, }; +const FRONTEND_ROOT: &str = match option_env!("REGALADE_FRONTEND_DOMAIN") { + None => "http://localhost:8080", + Some(v) => v, +}; + #[macro_export] macro_rules! api { ($($arg:tt)*) => {{ @@ -204,6 +209,37 @@ async fn do_login(username: String, password: String) -> anyhow::Result<()> { Ok(()) } +async fn check_oidc() -> anyhow::Result { + let rsp = gloo_net::http::Request::get(api!("login/has_oidc")) + .send() + .await?; + + Ok(rsp.status() == 200) +} + +fn Openid(cx: Scope) -> Element { + let has = use_future(cx, (), |()| check_oidc()); + + cx.render(match has.value().unwrap_or(&Ok(false)) { + Ok(true) => { + let route = api!("login/oidc").to_owned(); + let ret = urlencoding::encode(&format!("{FRONTEND_ROOT}/login/oidc")).to_string(); + rsx! { + a { + href: "{route}?return={ret}", + class: "mt-1 w-100 btn btn-lg btn-primary", + "Login with OpenID" + } + } + } + Ok(false) => rsx! {{}}, + Err(e) => { + log::error!("Could not check oidc status: {e:?}"); + rsx! {{}} + } + }) +} + fn Login(cx: Scope) -> Element { let error = use_state(cx, || None::); let router = use_router(cx); @@ -256,6 +292,7 @@ fn Login(cx: Scope) -> Element { label { "for": "floatingPass", "Password" } } button { class: "w-100 btn btn-lg btn-primary", "type": "submit", "Login" } + Openid {} } }) } @@ -505,9 +542,38 @@ fn Index(cx: Scope) -> Element { cx.render(rsx! {"INDEX"}) } +#[derive(Deserialize)] +struct OidcQuery { + token: String, + username: String, +} + +fn OidcRedirect(cx: Scope) -> Element { + let auth = use_route(cx).query::(); + + cx.render(match auth { + None => rsx! {"No authentication query, internal error."}, + Some(v) => { + match LocalStorage::set( + "token", + LoginInfo { + token: v.token, + name: v.username, + }, + ) { + Ok(_) => { + gloo_utils::window().location().replace("/").unwrap(); + rsx! {{}} + } + Err(_) => rsx! {"Could not store authentication, try again."}, + } + } + }) +} + fn App(cx: Scope) -> Element { cx.render(rsx! { - Router { + Router { Route { to: Page::Home.to(), RegaladeSidebar { current: Page::Home, Index {} } } @@ -524,6 +590,7 @@ fn App(cx: Scope) -> Element { RegaladeSidebar { current: Page::RecipeList, recipe::RecipeView {} } } Route { to: "/login", Login {} } + Route { to: "/login/oidc", OidcRedirect {} } Route { to: "/household_selection", LoginRedirect { HouseholdSelection {} } } diff --git a/app/src/recipe/creator.rs b/app/src/recipe/creator.rs index 5607064..60708a4 100644 --- a/app/src/recipe/creator.rs +++ b/app/src/recipe/creator.rs @@ -341,15 +341,15 @@ pub fn RecipeCreator(cx: Scope) -> Element { } div { h2 { "Steps" } - div {class: "text-start", - textarea { - class: "form-control", - id: "steps-area", - value: "{steps}", - rows: "10", - oninput: move |e| steps.set(e.value.clone()) + div { class: "text-start", + textarea { + class: "form-control", + id: "steps-area", + value: "{steps}", + rows: "10", + oninput: move |e| steps.set(e.value.clone()) + } } -} } hr {} ModalToggleButton { class: "btn btn-lg btn-primary", modal_id: "newRcpModal", "Create Recipe" } diff --git a/app/src/recipe/view.rs b/app/src/recipe/view.rs index 66df07a..850edba 100644 --- a/app/src/recipe/view.rs +++ b/app/src/recipe/view.rs @@ -700,9 +700,7 @@ fn RecipeViewer(cx: Scope) -> Element { hr {} div { class: "text-start", h2 { "Steps" } - div { - dangerous_inner_html: "{steps_rendered}" - } + div { dangerous_inner_html: "{steps_rendered}" } EditSteps { recipe: cx.props.id, refresh: cx.props.refresh.clone(), diff --git a/migration/src/lib.rs b/migration/src/lib.rs index d811bc7..80bc5d0 100644 --- a/migration/src/lib.rs +++ b/migration/src/lib.rs @@ -5,6 +5,7 @@ mod m20230520_203638_household; mod m20230529_184433_ingredients; mod m20230618_163416_recipe; mod m20230629_151746_recipe_person_count; +mod m20230726_203858_oidc; pub struct Migrator; @@ -17,6 +18,7 @@ impl MigratorTrait for Migrator { Box::new(m20230529_184433_ingredients::Migration), Box::new(m20230618_163416_recipe::Migration), Box::new(m20230629_151746_recipe_person_count::Migration), + Box::new(m20230726_203858_oidc::Migration), ] } } diff --git a/migration/src/m20220101_000001_account.rs b/migration/src/m20220101_000001_account.rs index d5c67cf..8805156 100644 --- a/migration/src/m20220101_000001_account.rs +++ b/migration/src/m20220101_000001_account.rs @@ -9,6 +9,7 @@ pub(crate) enum User { Id, Name, Password, + OpenIdSubject, } #[async_trait::async_trait] diff --git a/migration/src/m20230726_203858_oidc.rs b/migration/src/m20230726_203858_oidc.rs new file mode 100644 index 0000000..3440c3e --- /dev/null +++ b/migration/src/m20230726_203858_oidc.rs @@ -0,0 +1,40 @@ +use sea_orm_migration::prelude::*; + +use crate::m20220101_000001_account::User; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .alter_table( + Table::alter() + .table(User::Table) + .modify_column(ColumnDef::new(User::Password).null()) + .add_column(ColumnDef::new(User::OpenIdSubject).string().null()) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager.exec_stmt( + Query::delete() + .from_table(User::Table) + .cond_where(Expr::col(User::Password).is_null()) + .to_owned(), + ).await?; + + manager + .alter_table( + Table::alter() + .table(User::Table) + .drop_column(User::OpenIdSubject) + .modify_column(ColumnDef::new(User::Password).not_null()) + .to_owned(), + ) + .await + } +} diff --git a/src/entity/user.rs b/src/entity/user.rs index c3f901e..ae683d7 100644 --- a/src/entity/user.rs +++ b/src/entity/user.rs @@ -9,8 +9,9 @@ pub struct Model { pub id: Uuid, #[sea_orm(unique)] pub name: String, - #[sea_orm(column_type = "Binary(BlobSize::Blob(None))")] - pub password: Vec, + #[sea_orm(column_type = "Binary(BlobSize::Blob(None))", nullable)] + pub password: Option>, + pub open_id_subject: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/src/main.rs b/src/main.rs index 70faa04..e4b2714 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,26 @@ -use std::{net::SocketAddr, path::PathBuf, sync::Arc}; +use std::{ + collections::{HashMap, VecDeque}, + net::SocketAddr, + path::PathBuf, + sync::Arc, +}; +use anyhow::anyhow; use axum::Router; use base64::{engine::general_purpose, Engine}; use jwt_simple::prelude::HS256Key; use migration::{Migrator, MigratorTrait}; +use openidconnect::{ + core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata}, + url::Url, + AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, + OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse, +}; use sea_orm::{ConnectOptions, Database, DatabaseConnection}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tower_http::services::{ServeDir, ServeFile}; use tracing_subscriber::EnvFilter; +use uuid::Uuid; pub(crate) mod entity; mod routes; @@ -65,6 +78,44 @@ impl<'de> Deserialize<'de> for Base64 { } } +fn deserialize_comma<'de, D>(de: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + use serde::de::Visitor; + + struct CommaVisitor; + impl<'de> Visitor<'de> for CommaVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("string containg comma separated strings") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(v.split(',') + .map(|v| openidconnect::Scope::new(v.to_string())) + .collect()) + } + } + + de.deserialize_str(CommaVisitor) +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "UPPERCASE")] +struct OpenidConnectSettings { + url: String, + id: ClientId, + secret: ClientSecret, + domain: String, + #[serde(deserialize_with = "deserialize_comma")] + scopes: Vec, +} + fn default_host() -> String { "0.0.0.0".into() } @@ -87,6 +138,8 @@ struct Settings { serve_app: Option, database_url: String, #[serde(default)] + oidc: Option, + #[serde(default)] sqlx_logging: bool, } @@ -102,6 +155,198 @@ impl Settings { struct AppState { jwt_secret: Base64, db: DatabaseConnection, + oidc: Option, +} + +struct OpenidConnector { + provider: CoreClient, + scopes: Vec, + domain: Url, + inflight: parking_lot::Mutex, +} + +struct FifoMap { + max: usize, + order: VecDeque, + values: HashMap, +} + +struct FifoMapInsert<'a> { + pub id: Uuid, + map: &'a mut FifoMap, +} + +impl<'a> FifoMapInsert<'a> { + pub fn insert(self, state: OpenidAuthState) { + let FifoMapInsert { id, map } = self; + + map.order.push_back(id); + if map.order.len() > map.max { + let first = map.order.pop_front().unwrap(); + map.values.remove(&first); + } + + map.values.insert(id, state); + } +} + +impl FifoMap { + pub fn new(max: usize) -> Self { + Self { + max, + order: VecDeque::new(), + values: HashMap::new(), + } + } + + pub fn new_entry(&mut self) -> FifoMapInsert { + let id = loop { + let id = Uuid::new_v4(); + if !self.values.contains_key(&id) { + break id; + } + }; + + FifoMapInsert { id, map: self } + } + + pub fn remove(&mut self, id: Uuid) -> Option { + let state = self.values.remove(&id); + + if let Some(idx) = self.order.iter().position(|&v| v == id) { + self.order.remove(idx); + }; + + state + } +} + +struct OpenidAuthState { + pkce_verifier: PkceCodeVerifier, + csrf_token: CsrfToken, + nonce: Nonce, + source_url: String, +} + +pub struct OpenidAccount { + pub sub: String, + pub name: String, + pub source_url: String, +} + +impl OpenidConnector { + async fn new(settings: OpenidConnectSettings) -> anyhow::Result { + let metadata = CoreProviderMetadata::discover_async( + IssuerUrl::new(settings.url)?, + openidconnect::reqwest::async_http_client, + ) + .await?; + + let provider = + CoreClient::from_provider_metadata(metadata, settings.id, Some(settings.secret)); + + Ok(Self { + provider, + scopes: settings.scopes, + domain: settings.domain.parse()?, + inflight: parking_lot::Mutex::new(FifoMap::new(1024)), + }) + } + + pub fn start_auth(&self, source_url: String) -> Url { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + let mut inflight = self.inflight.lock(); + + let slot = inflight.new_entry(); + + tracing::info!("Login flow with id {}", slot.id); + + let (url, csrf_token, nonce) = self + .provider + .authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ) + .add_scopes(self.scopes.iter().cloned()) + .set_pkce_challenge(pkce_challenge) + .set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url( + self.domain + .join(&format!("api/login/redirect/{}", slot.id)) + .unwrap(), + ))) + .url(); + + slot.insert(OpenidAuthState { + pkce_verifier, + csrf_token, + nonce, + source_url, + }); + + url + } + + pub async fn redirected( + &self, + id: Uuid, + csrf_state: String, + code: String, + ) -> anyhow::Result { + let Some(state) = self.inflight.lock().remove(id) else { + anyhow::bail!("Redirect has expired or has already been submitted") + }; + + tracing::info!("Redirected to continue login flow for {id}"); + + if state.csrf_token.secret() != &csrf_state { + anyhow::bail!( + "CRSF state does not match, expected: {} got {csrf_state}", + state.csrf_token.secret() + ) + } + + let token_response = self + .provider + .exchange_code(AuthorizationCode::new(code)) + .set_pkce_verifier(state.pkce_verifier) + .set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url( + self.domain + .join(&format!("api/login/redirect/{}", id)) + .unwrap(), + ))) + .request_async(openidconnect::reqwest::async_http_client) + .await?; + + let Some(id_token) = token_response.id_token() else { + anyhow::bail!("Server did not return an ID token") + }; + + let claims = id_token.claims(&self.provider.id_token_verifier(), &state.nonce)?; + + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = AccessTokenHash::from_token( + token_response.access_token(), + &id_token.signing_alg()?, + )?; + + if actual_access_token_hash != *expected_access_token_hash { + anyhow::bail!("Invalid access token") + } + } + + Ok(OpenidAccount { + sub: claims.subject().to_string(), + name: claims + .name() + .and_then(|v| v.get(None)) + .map(|v| v.to_string()) + .or(claims.preferred_username().map(|v| v.to_string())) + .ok_or_else(|| anyhow!("No name or preferred_username"))?, + source_url: state.source_url, + }) + } } #[tokio::main] @@ -121,9 +366,15 @@ async fn main() -> anyhow::Result<()> { let mut opt = ConnectOptions::new(config.database_url); opt.sqlx_logging(config.sqlx_logging); + let oidc = match config.oidc { + None => None, + Some(settings) => Some(OpenidConnector::new(settings).await?), + }; + let state = Arc::new(AppState { jwt_secret: config.jwt_secret, db: Database::connect(opt).await?, + oidc, }); Migrator::up(&state.db, None).await?; @@ -131,7 +382,10 @@ async fn main() -> anyhow::Result<()> { let router = Router::new() .nest( "/api", - routes::router(config.api_allowed.map(|s| s.parse()).transpose()?), + routes::router( + config.api_allowed.map(|s| s.parse()).transpose()?, + state.oidc.is_some(), + ), ) .with_state(state); diff --git a/src/routes/mod.rs b/src/routes/mod.rs index e42d9f3..8b32fc9 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -3,19 +3,19 @@ use std::sync::Arc; use api::{LoginRequest, LoginResponse, UserInfo}; use axum::{ async_trait, - extract::{FromRef, FromRequestParts, Path, State}, + extract::{FromRef, FromRequestParts, Path, Query, State}, headers::{authorization::Bearer, Authorization}, http::{ header::{AUTHORIZATION, CONTENT_TYPE}, request::Parts, HeaderValue, Method, StatusCode, }, - response::IntoResponse, + response::{IntoResponse, Redirect}, routing::{delete, get, patch, post, put}, Json, Router, TypedHeader, }; use jwt_simple::prelude::*; -use sea_orm::{prelude::*, TransactionError}; +use sea_orm::{prelude::*, ActiveValue, TransactionError}; use sha2::{Digest, Sha512}; use tower_http::cors::{self, AllowOrigin, CorsLayer}; @@ -45,6 +45,8 @@ enum RouteError { RessourceNotFound, #[error("The request was malformed")] InvalidRequest(String), + #[error("A normal account with this name already exists")] + NormalAccount, #[error("Error in DB transaction")] TxnError(#[from] TransactionError>), } @@ -77,6 +79,9 @@ impl IntoResponse for RouteError { .into_response(), RouteError::RessourceNotFound => StatusCode::NOT_FOUND.into_response(), RouteError::InvalidRequest(reason) => (StatusCode::BAD_REQUEST, reason).into_response(), + s @ RouteError::NormalAccount => { + (StatusCode::BAD_REQUEST, s.to_string()).into_response() + } e => { tracing::error!("Internal error: {e:?}"); StatusCode::INTERNAL_SERVER_ERROR.into_response() @@ -138,12 +143,16 @@ async fn login( return Err(RouteError::UnknownAccount) }; + let Some(password) = user.password.as_ref() else { + return Err(RouteError::UnknownAccount); + }; + let mut hasher = Sha512::new(); hasher.update(user.id.as_bytes()); hasher.update(req.password.as_bytes()); let hash = hasher.finalize(); - if hash[..] != user.password { + if &hash[..] != password { return Err(RouteError::UnknownAccount); } @@ -155,6 +164,92 @@ async fn login( Ok(Json(LoginResponse { token })) } +#[derive(Deserialize)] +struct OidcStartParam { + r#return: String, +} + +async fn oidc_login( + State(state): State, + Query(start): Query, +) -> Result { + tracing::info!("Starting OIDC login"); + let oidc = state.oidc.as_ref().unwrap(); + + let redirect_url = oidc.start_auth(start.r#return); + + Ok(Redirect::to(redirect_url.as_str())) +} + +#[derive(Deserialize)] +struct OidcRedirectParams { + state: String, + code: String, +} + +async fn oidc_login_finish( + State(state): State, + Path(id): Path, + Query(redirect): Query, +) -> Result { + match state + .oidc + .as_ref() + .unwrap() + .redirected(id, redirect.state, redirect.code) + .await + { + Err(e) => { + tracing::error!("Error when finishing OAuth2 flow {e:?}"); + Err(RouteError::Unauthorized) + } + Ok(account) => { + let user = User::find() + .filter( + user::Column::Name + .eq(&account.name) + .or(user::Column::OpenIdSubject.eq(&account.sub)), + ) + .one(&state.db) + .await?; + + let user = match user { + None => { + let model = user::ActiveModel { + id: ActiveValue::Set(Uuid::new_v4()), + name: ActiveValue::Set(account.name), + password: ActiveValue::NotSet, + open_id_subject: ActiveValue::Set(Some(account.sub)), + }; + + model.insert(&state.db).await? + } + Some(user) => { + if user.open_id_subject.as_ref() != Some(&account.sub) { + return Err(RouteError::NormalAccount); + } + + user + } + }; + + let mut claims = Claims::create(Duration::from_secs(3600 * 24 * 31 * 6)); + claims.subject = Some(user.id.to_string()); + + let token = state.jwt_secret.0.authenticate(claims)?; + + let redirect = format!( + "{}?token={}&username={}", + account.source_url, + urlencoding::encode(&token), + urlencoding::encode(&user.name), + ); + + Ok(Redirect::to(&redirect)) + } + } +} + async fn get_user_id( _: AuthenticatedUser, State(state): State, @@ -173,7 +268,7 @@ async fn get_user_id( }))) } -pub(crate) fn router(api_allowed: Option) -> Router { +pub(crate) fn router(api_allowed: Option, has_oidc: bool) -> Router { let origin: AllowOrigin = match api_allowed { Some(n) => n.into(), None => cors::Any.into(), @@ -185,7 +280,7 @@ pub(crate) fn router(api_allowed: Option) -> Router { let mk_service = |m: Vec| cors_base.clone().allow_methods(m); - Router::new() + let router = Router::new() .route( "/search/user/:name", get(get_user_id).layer(mk_service(vec![Method::GET])), @@ -246,5 +341,18 @@ pub(crate) fn router(api_allowed: Option) -> Router { .delete(recipe::delete_ig) .put(recipe::add_ig_request) .layer(mk_service(vec![Method::PATCH, Method::DELETE, Method::PUT])), - ) + ); + + if has_oidc { + async fn unit() {} + router + .route("/login/oidc", get(oidc_login)) + .route( + "/login/has_oidc", + get(unit).layer(mk_service(vec![Method::GET])), + ) + .route("/login/redirect/:id", get(oidc_login_finish)) + } else { + router + } }