Allow to login with OIDC

This commit is contained in:
traxys 2023-07-27 00:06:36 +02:00
parent f556fec3bb
commit 7a2ff7ad1d
12 changed files with 782 additions and 29 deletions

288
Cargo.lock generated
View file

@ -739,6 +739,7 @@ dependencies = [
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 2.0.16",
]
@ -1000,6 +1001,7 @@ dependencies = [
"log",
"pulldown-cmark",
"serde",
"urlencoding",
"uuid",
"wasm-bindgen",
]
@ -1030,6 +1032,12 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]]
name = "dyn-clone"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "304e6508efa593091e97a9abbc10f90aa7ca635b6d2784feff3c89d41dd12272"
[[package]]
name = "ecdsa"
version = "0.16.7"
@ -1081,6 +1089,15 @@ dependencies = [
"zeroize",
]
[[package]]
name = "encoding_rs"
version = "0.8.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394"
dependencies = [
"cfg-if",
]
[[package]]
name = "enumset"
version = "1.1.2"
@ -1410,8 +1427,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@ -1509,6 +1528,25 @@ dependencies = [
"subtle",
]
[[package]]
name = "h2"
version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http",
"indexmap",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@ -1715,6 +1753,7 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"httparse",
@ -1728,6 +1767,20 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97"
dependencies = [
"futures-util",
"http",
"hyper",
"rustls 0.21.5",
"tokio",
"tokio-rustls 0.24.1",
]
[[package]]
name = "iana-time-zone"
version = "0.1.56"
@ -1792,6 +1845,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
"serde",
]
[[package]]
@ -1870,6 +1924,12 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "ipnet"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6"
[[package]]
name = "itertools"
version = "0.10.5"
@ -1921,7 +1981,7 @@ dependencies = [
"p256",
"p384",
"rand",
"rsa",
"rsa 0.7.2",
"serde",
"serde_json",
"spki 0.6.0",
@ -2275,12 +2335,63 @@ dependencies = [
"libc",
]
[[package]]
name = "oauth2"
version = "4.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09a6e2a2b13a56ebeabba9142f911745be6456163fd6c3d361274ebcd891a80c"
dependencies = [
"base64 0.13.1",
"chrono",
"getrandom",
"http",
"rand",
"reqwest",
"serde",
"serde_json",
"serde_path_to_error",
"sha2",
"thiserror",
"url",
]
[[package]]
name = "once_cell"
version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]]
name = "openidconnect"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03335ade401352b354b017e7597ddb40040091da445b031bf659e597e032b1fc"
dependencies = [
"base64 0.13.1",
"chrono",
"dyn-clone",
"hmac",
"http",
"itertools 0.10.5",
"log",
"oauth2",
"p256",
"p384",
"rand",
"rsa 0.9.2",
"serde",
"serde-value",
"serde_derive",
"serde_json",
"serde_path_to_error",
"serde_plain",
"serde_with",
"sha2",
"subtle",
"thiserror",
"url",
]
[[package]]
name = "ordered-float"
version = "2.10.0"
@ -2515,6 +2626,17 @@ dependencies = [
"zeroize",
]
[[package]]
name = "pkcs1"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f"
dependencies = [
"der 0.7.6",
"pkcs8 0.10.2",
"spki 0.7.2",
]
[[package]]
name = "pkcs8"
version = "0.9.0"
@ -2737,6 +2859,8 @@ dependencies = [
"envious",
"jwt-simple",
"migration",
"openidconnect",
"parking_lot 0.12.1",
"sea-orm",
"sea-query",
"serde",
@ -2746,6 +2870,7 @@ dependencies = [
"tower-http",
"tracing",
"tracing-subscriber",
"urlencoding",
"uuid",
]
@ -2790,6 +2915,45 @@ dependencies = [
"bytecheck",
]
[[package]]
name = "reqwest"
version = "0.11.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55"
dependencies = [
"base64 0.21.0",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"hyper-rustls",
"ipnet",
"js-sys",
"log",
"mime",
"once_cell",
"percent-encoding",
"pin-project-lite",
"rustls 0.21.5",
"rustls-pemfile",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-rustls 0.24.1",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"webpki-roots",
"winreg",
]
[[package]]
name = "rfc6979"
version = "0.4.0"
@ -2855,7 +3019,7 @@ dependencies = [
"num-integer",
"num-iter",
"num-traits",
"pkcs1",
"pkcs1 0.4.1",
"pkcs8 0.9.0",
"rand_core",
"signature 1.6.4",
@ -2864,6 +3028,28 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rsa"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ab43bb47d23c1a631b4b680199a45255dce26fa9ab2fa902581f624ff13e6a8"
dependencies = [
"byteorder",
"const-oid",
"digest",
"num-bigint-dig",
"num-integer",
"num-iter",
"num-traits",
"pkcs1 0.7.5",
"pkcs8 0.10.2",
"rand_core",
"signature 2.1.0",
"spki 0.7.2",
"subtle",
"zeroize",
]
[[package]]
name = "rust_decimal"
version = "1.29.1"
@ -2923,6 +3109,18 @@ dependencies = [
"webpki",
]
[[package]]
name = "rustls"
version = "0.21.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79ea77c539259495ce8ca47f53e66ae0330a8819f67e23ac96ca02f50e7b7d36"
dependencies = [
"log",
"ring",
"rustls-webpki",
"sct",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.2"
@ -2932,6 +3130,16 @@ dependencies = [
"base64 0.21.0",
]
[[package]]
name = "rustls-webpki"
version = "0.101.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15f36a6828982f422756984e47912a7a51dcbc2a197aa791158f8ca61cd8204e"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.12"
@ -3219,6 +3427,15 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_plain"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6018081315db179d0ce57b1fe4b62a12a0028c9cf9bbef868c9cf477b3c34ae"
dependencies = [
"serde",
]
[[package]]
name = "serde_repr"
version = "0.1.12"
@ -3242,6 +3459,34 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_with"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21e47d95bc83ed33b2ecf84f4187ad1ab9685d18ff28db000c99deac8ce180e3"
dependencies = [
"base64 0.21.0",
"chrono",
"hex",
"indexmap",
"serde",
"serde_json",
"serde_with_macros",
"time 0.3.21",
]
[[package]]
name = "serde_with_macros"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea3cee93715c2e266b9338b7544da68a9f24e227722ba482bd1c024367c77c65"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.16",
]
[[package]]
name = "sha1"
version = "0.10.5"
@ -3486,7 +3731,7 @@ dependencies = [
"percent-encoding",
"rand",
"rust_decimal",
"rustls",
"rustls 0.20.8",
"rustls-pemfile",
"serde",
"serde_json",
@ -3532,7 +3777,7 @@ checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024"
dependencies = [
"once_cell",
"tokio",
"tokio-rustls",
"tokio-rustls 0.23.4",
]
[[package]]
@ -3571,6 +3816,12 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subtle"
version = "2.4.1"
@ -3753,11 +4004,21 @@ version = "0.23.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59"
dependencies = [
"rustls",
"rustls 0.20.8",
"tokio",
"webpki",
]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
dependencies = [
"rustls 0.21.5",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.14"
@ -3780,6 +4041,7 @@ dependencies = [
"futures-sink",
"pin-project-lite",
"tokio",
"tracing",
]
[[package]]
@ -3982,8 +4244,15 @@ dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
"serde",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf-8"
version = "0.7.6"
@ -4360,6 +4629,15 @@ version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a"
[[package]]
name = "winreg"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
dependencies = [
"winapi",
]
[[package]]
name = "wyz"
version = "0.5.1"

View file

@ -23,7 +23,10 @@ tower-http = { version = "0.4.0", features = ["cors", "fs"] }
sha2 = "0.10"
uuid = { version = "1.3", features = ["v4"] }
sea-query = "0.28"
openidconnect = "3.3.0"
envious = "0.2.2"
parking_lot = "0.12.1"
urlencoding = "2.1.3"
[dependencies.sea-orm]
version = "0.11"

View file

@ -21,5 +21,6 @@ itertools = "0.11.0"
log = "0.4.19"
pulldown-cmark = "0.9.3"
serde = { version = "1.0.164", features = ["derive"] }
urlencoding = "2.1.3"
uuid = "1.4.0"
wasm-bindgen = "0.2.87"

View file

@ -6,7 +6,7 @@ use api::{
LoginResponse, UserInfo,
};
use dioxus::prelude::*;
use dioxus_router::{use_router, Route, Router};
use dioxus_router::{use_route, use_router, Redirect, Route, Router};
use gloo_storage::{errors::StorageError, LocalStorage, Storage};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
@ -33,6 +33,11 @@ const API_ROUTE: &str = match option_env!("REGALADE_API_SERVER_BASE") {
Some(v) => v,
};
const FRONTEND_ROOT: &str = match option_env!("REGALADE_FRONTEND_DOMAIN") {
None => "http://localhost:8080",
Some(v) => v,
};
#[macro_export]
macro_rules! api {
($($arg:tt)*) => {{
@ -204,6 +209,37 @@ async fn do_login(username: String, password: String) -> anyhow::Result<()> {
Ok(())
}
async fn check_oidc() -> anyhow::Result<bool> {
let rsp = gloo_net::http::Request::get(api!("login/has_oidc"))
.send()
.await?;
Ok(rsp.status() == 200)
}
fn Openid(cx: Scope) -> Element {
let has = use_future(cx, (), |()| check_oidc());
cx.render(match has.value().unwrap_or(&Ok(false)) {
Ok(true) => {
let route = api!("login/oidc").to_owned();
let ret = urlencoding::encode(&format!("{FRONTEND_ROOT}/login/oidc")).to_string();
rsx! {
a {
href: "{route}?return={ret}",
class: "mt-1 w-100 btn btn-lg btn-primary",
"Login with OpenID"
}
}
}
Ok(false) => rsx! {{}},
Err(e) => {
log::error!("Could not check oidc status: {e:?}");
rsx! {{}}
}
})
}
fn Login(cx: Scope) -> Element {
let error = use_state(cx, || None::<String>);
let router = use_router(cx);
@ -256,6 +292,7 @@ fn Login(cx: Scope) -> Element {
label { "for": "floatingPass", "Password" }
}
button { class: "w-100 btn btn-lg btn-primary", "type": "submit", "Login" }
Openid {}
}
})
}
@ -505,6 +542,35 @@ fn Index(cx: Scope) -> Element {
cx.render(rsx! {"INDEX"})
}
#[derive(Deserialize)]
struct OidcQuery {
token: String,
username: String,
}
fn OidcRedirect(cx: Scope) -> Element {
let auth = use_route(cx).query::<OidcQuery>();
cx.render(match auth {
None => rsx! {"No authentication query, internal error."},
Some(v) => {
match LocalStorage::set(
"token",
LoginInfo {
token: v.token,
name: v.username,
},
) {
Ok(_) => {
gloo_utils::window().location().replace("/").unwrap();
rsx! {{}}
}
Err(_) => rsx! {"Could not store authentication, try again."},
}
}
})
}
fn App(cx: Scope) -> Element {
cx.render(rsx! {
Router {
@ -524,6 +590,7 @@ fn App(cx: Scope) -> Element {
RegaladeSidebar { current: Page::RecipeList, recipe::RecipeView {} }
}
Route { to: "/login", Login {} }
Route { to: "/login/oidc", OidcRedirect {} }
Route { to: "/household_selection",
LoginRedirect { HouseholdSelection {} }
}

View file

@ -341,7 +341,7 @@ pub fn RecipeCreator(cx: Scope) -> Element {
}
div {
h2 { "Steps" }
div {class: "text-start",
div { class: "text-start",
textarea {
class: "form-control",
id: "steps-area",
@ -349,7 +349,7 @@ pub fn RecipeCreator(cx: Scope) -> Element {
rows: "10",
oninput: move |e| steps.set(e.value.clone())
}
}
}
}
hr {}
ModalToggleButton { class: "btn btn-lg btn-primary", modal_id: "newRcpModal", "Create Recipe" }

View file

@ -700,9 +700,7 @@ fn RecipeViewer(cx: Scope<RecipeViewerProps>) -> Element {
hr {}
div { class: "text-start",
h2 { "Steps" }
div {
dangerous_inner_html: "{steps_rendered}"
}
div { dangerous_inner_html: "{steps_rendered}" }
EditSteps {
recipe: cx.props.id,
refresh: cx.props.refresh.clone(),

View file

@ -5,6 +5,7 @@ mod m20230520_203638_household;
mod m20230529_184433_ingredients;
mod m20230618_163416_recipe;
mod m20230629_151746_recipe_person_count;
mod m20230726_203858_oidc;
pub struct Migrator;
@ -17,6 +18,7 @@ impl MigratorTrait for Migrator {
Box::new(m20230529_184433_ingredients::Migration),
Box::new(m20230618_163416_recipe::Migration),
Box::new(m20230629_151746_recipe_person_count::Migration),
Box::new(m20230726_203858_oidc::Migration),
]
}
}

View file

@ -9,6 +9,7 @@ pub(crate) enum User {
Id,
Name,
Password,
OpenIdSubject,
}
#[async_trait::async_trait]

View file

@ -0,0 +1,40 @@
use sea_orm_migration::prelude::*;
use crate::m20220101_000001_account::User;
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.alter_table(
Table::alter()
.table(User::Table)
.modify_column(ColumnDef::new(User::Password).null())
.add_column(ColumnDef::new(User::OpenIdSubject).string().null())
.to_owned(),
)
.await
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager.exec_stmt(
Query::delete()
.from_table(User::Table)
.cond_where(Expr::col(User::Password).is_null())
.to_owned(),
).await?;
manager
.alter_table(
Table::alter()
.table(User::Table)
.drop_column(User::OpenIdSubject)
.modify_column(ColumnDef::new(User::Password).not_null())
.to_owned(),
)
.await
}
}

View file

@ -9,8 +9,9 @@ pub struct Model {
pub id: Uuid,
#[sea_orm(unique)]
pub name: String,
#[sea_orm(column_type = "Binary(BlobSize::Blob(None))")]
pub password: Vec<u8>,
#[sea_orm(column_type = "Binary(BlobSize::Blob(None))", nullable)]
pub password: Option<Vec<u8>>,
pub open_id_subject: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View file

@ -1,13 +1,26 @@
use std::{net::SocketAddr, path::PathBuf, sync::Arc};
use std::{
collections::{HashMap, VecDeque},
net::SocketAddr,
path::PathBuf,
sync::Arc,
};
use anyhow::anyhow;
use axum::Router;
use base64::{engine::general_purpose, Engine};
use jwt_simple::prelude::HS256Key;
use migration::{Migrator, MigratorTrait};
use openidconnect::{
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
url::Url,
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse,
};
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tower_http::services::{ServeDir, ServeFile};
use tracing_subscriber::EnvFilter;
use uuid::Uuid;
pub(crate) mod entity;
mod routes;
@ -65,6 +78,44 @@ impl<'de> Deserialize<'de> for Base64 {
}
}
fn deserialize_comma<'de, D>(de: D) -> Result<Vec<openidconnect::Scope>, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Visitor;
struct CommaVisitor;
impl<'de> Visitor<'de> for CommaVisitor {
type Value = Vec<openidconnect::Scope>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("string containg comma separated strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.split(',')
.map(|v| openidconnect::Scope::new(v.to_string()))
.collect())
}
}
de.deserialize_str(CommaVisitor)
}
#[derive(Deserialize, Debug)]
#[serde(rename_all = "UPPERCASE")]
struct OpenidConnectSettings {
url: String,
id: ClientId,
secret: ClientSecret,
domain: String,
#[serde(deserialize_with = "deserialize_comma")]
scopes: Vec<openidconnect::Scope>,
}
fn default_host() -> String {
"0.0.0.0".into()
}
@ -87,6 +138,8 @@ struct Settings {
serve_app: Option<PathBuf>,
database_url: String,
#[serde(default)]
oidc: Option<OpenidConnectSettings>,
#[serde(default)]
sqlx_logging: bool,
}
@ -102,6 +155,198 @@ impl Settings {
struct AppState {
jwt_secret: Base64,
db: DatabaseConnection,
oidc: Option<OpenidConnector>,
}
struct OpenidConnector {
provider: CoreClient,
scopes: Vec<openidconnect::Scope>,
domain: Url,
inflight: parking_lot::Mutex<FifoMap>,
}
struct FifoMap {
max: usize,
order: VecDeque<Uuid>,
values: HashMap<Uuid, OpenidAuthState>,
}
struct FifoMapInsert<'a> {
pub id: Uuid,
map: &'a mut FifoMap,
}
impl<'a> FifoMapInsert<'a> {
pub fn insert(self, state: OpenidAuthState) {
let FifoMapInsert { id, map } = self;
map.order.push_back(id);
if map.order.len() > map.max {
let first = map.order.pop_front().unwrap();
map.values.remove(&first);
}
map.values.insert(id, state);
}
}
impl FifoMap {
pub fn new(max: usize) -> Self {
Self {
max,
order: VecDeque::new(),
values: HashMap::new(),
}
}
pub fn new_entry(&mut self) -> FifoMapInsert {
let id = loop {
let id = Uuid::new_v4();
if !self.values.contains_key(&id) {
break id;
}
};
FifoMapInsert { id, map: self }
}
pub fn remove(&mut self, id: Uuid) -> Option<OpenidAuthState> {
let state = self.values.remove(&id);
if let Some(idx) = self.order.iter().position(|&v| v == id) {
self.order.remove(idx);
};
state
}
}
struct OpenidAuthState {
pkce_verifier: PkceCodeVerifier,
csrf_token: CsrfToken,
nonce: Nonce,
source_url: String,
}
pub struct OpenidAccount {
pub sub: String,
pub name: String,
pub source_url: String,
}
impl OpenidConnector {
async fn new(settings: OpenidConnectSettings) -> anyhow::Result<Self> {
let metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(settings.url)?,
openidconnect::reqwest::async_http_client,
)
.await?;
let provider =
CoreClient::from_provider_metadata(metadata, settings.id, Some(settings.secret));
Ok(Self {
provider,
scopes: settings.scopes,
domain: settings.domain.parse()?,
inflight: parking_lot::Mutex::new(FifoMap::new(1024)),
})
}
pub fn start_auth(&self, source_url: String) -> Url {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut inflight = self.inflight.lock();
let slot = inflight.new_entry();
tracing::info!("Login flow with id {}", slot.id);
let (url, csrf_token, nonce) = self
.provider
.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
)
.add_scopes(self.scopes.iter().cloned())
.set_pkce_challenge(pkce_challenge)
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
self.domain
.join(&format!("api/login/redirect/{}", slot.id))
.unwrap(),
)))
.url();
slot.insert(OpenidAuthState {
pkce_verifier,
csrf_token,
nonce,
source_url,
});
url
}
pub async fn redirected(
&self,
id: Uuid,
csrf_state: String,
code: String,
) -> anyhow::Result<OpenidAccount> {
let Some(state) = self.inflight.lock().remove(id) else {
anyhow::bail!("Redirect has expired or has already been submitted")
};
tracing::info!("Redirected to continue login flow for {id}");
if state.csrf_token.secret() != &csrf_state {
anyhow::bail!(
"CRSF state does not match, expected: {} got {csrf_state}",
state.csrf_token.secret()
)
}
let token_response = self
.provider
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(state.pkce_verifier)
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
self.domain
.join(&format!("api/login/redirect/{}", id))
.unwrap(),
)))
.request_async(openidconnect::reqwest::async_http_client)
.await?;
let Some(id_token) = token_response.id_token() else {
anyhow::bail!("Server did not return an ID token")
};
let claims = id_token.claims(&self.provider.id_token_verifier(), &state.nonce)?;
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(
token_response.access_token(),
&id_token.signing_alg()?,
)?;
if actual_access_token_hash != *expected_access_token_hash {
anyhow::bail!("Invalid access token")
}
}
Ok(OpenidAccount {
sub: claims.subject().to_string(),
name: claims
.name()
.and_then(|v| v.get(None))
.map(|v| v.to_string())
.or(claims.preferred_username().map(|v| v.to_string()))
.ok_or_else(|| anyhow!("No name or preferred_username"))?,
source_url: state.source_url,
})
}
}
#[tokio::main]
@ -121,9 +366,15 @@ async fn main() -> anyhow::Result<()> {
let mut opt = ConnectOptions::new(config.database_url);
opt.sqlx_logging(config.sqlx_logging);
let oidc = match config.oidc {
None => None,
Some(settings) => Some(OpenidConnector::new(settings).await?),
};
let state = Arc::new(AppState {
jwt_secret: config.jwt_secret,
db: Database::connect(opt).await?,
oidc,
});
Migrator::up(&state.db, None).await?;
@ -131,7 +382,10 @@ async fn main() -> anyhow::Result<()> {
let router = Router::new()
.nest(
"/api",
routes::router(config.api_allowed.map(|s| s.parse()).transpose()?),
routes::router(
config.api_allowed.map(|s| s.parse()).transpose()?,
state.oidc.is_some(),
),
)
.with_state(state);

View file

@ -3,19 +3,19 @@ use std::sync::Arc;
use api::{LoginRequest, LoginResponse, UserInfo};
use axum::{
async_trait,
extract::{FromRef, FromRequestParts, Path, State},
extract::{FromRef, FromRequestParts, Path, Query, State},
headers::{authorization::Bearer, Authorization},
http::{
header::{AUTHORIZATION, CONTENT_TYPE},
request::Parts,
HeaderValue, Method, StatusCode,
},
response::IntoResponse,
response::{IntoResponse, Redirect},
routing::{delete, get, patch, post, put},
Json, Router, TypedHeader,
};
use jwt_simple::prelude::*;
use sea_orm::{prelude::*, TransactionError};
use sea_orm::{prelude::*, ActiveValue, TransactionError};
use sha2::{Digest, Sha512};
use tower_http::cors::{self, AllowOrigin, CorsLayer};
@ -45,6 +45,8 @@ enum RouteError {
RessourceNotFound,
#[error("The request was malformed")]
InvalidRequest(String),
#[error("A normal account with this name already exists")]
NormalAccount,
#[error("Error in DB transaction")]
TxnError(#[from] TransactionError<Box<RouteError>>),
}
@ -77,6 +79,9 @@ impl IntoResponse for RouteError {
.into_response(),
RouteError::RessourceNotFound => StatusCode::NOT_FOUND.into_response(),
RouteError::InvalidRequest(reason) => (StatusCode::BAD_REQUEST, reason).into_response(),
s @ RouteError::NormalAccount => {
(StatusCode::BAD_REQUEST, s.to_string()).into_response()
}
e => {
tracing::error!("Internal error: {e:?}");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
@ -138,12 +143,16 @@ async fn login(
return Err(RouteError::UnknownAccount)
};
let Some(password) = user.password.as_ref() else {
return Err(RouteError::UnknownAccount);
};
let mut hasher = Sha512::new();
hasher.update(user.id.as_bytes());
hasher.update(req.password.as_bytes());
let hash = hasher.finalize();
if hash[..] != user.password {
if &hash[..] != password {
return Err(RouteError::UnknownAccount);
}
@ -155,6 +164,92 @@ async fn login(
Ok(Json(LoginResponse { token }))
}
#[derive(Deserialize)]
struct OidcStartParam {
r#return: String,
}
async fn oidc_login(
State(state): State<AppState>,
Query(start): Query<OidcStartParam>,
) -> Result<Redirect, RouteError> {
tracing::info!("Starting OIDC login");
let oidc = state.oidc.as_ref().unwrap();
let redirect_url = oidc.start_auth(start.r#return);
Ok(Redirect::to(redirect_url.as_str()))
}
#[derive(Deserialize)]
struct OidcRedirectParams {
state: String,
code: String,
}
async fn oidc_login_finish(
State(state): State<AppState>,
Path(id): Path<Uuid>,
Query(redirect): Query<OidcRedirectParams>,
) -> Result<Redirect, RouteError> {
match state
.oidc
.as_ref()
.unwrap()
.redirected(id, redirect.state, redirect.code)
.await
{
Err(e) => {
tracing::error!("Error when finishing OAuth2 flow {e:?}");
Err(RouteError::Unauthorized)
}
Ok(account) => {
let user = User::find()
.filter(
user::Column::Name
.eq(&account.name)
.or(user::Column::OpenIdSubject.eq(&account.sub)),
)
.one(&state.db)
.await?;
let user = match user {
None => {
let model = user::ActiveModel {
id: ActiveValue::Set(Uuid::new_v4()),
name: ActiveValue::Set(account.name),
password: ActiveValue::NotSet,
open_id_subject: ActiveValue::Set(Some(account.sub)),
};
model.insert(&state.db).await?
}
Some(user) => {
if user.open_id_subject.as_ref() != Some(&account.sub) {
return Err(RouteError::NormalAccount);
}
user
}
};
let mut claims = Claims::create(Duration::from_secs(3600 * 24 * 31 * 6));
claims.subject = Some(user.id.to_string());
let token = state.jwt_secret.0.authenticate(claims)?;
let redirect = format!(
"{}?token={}&username={}",
account.source_url,
urlencoding::encode(&token),
urlencoding::encode(&user.name),
);
Ok(Redirect::to(&redirect))
}
}
}
async fn get_user_id(
_: AuthenticatedUser,
State(state): State<AppState>,
@ -173,7 +268,7 @@ async fn get_user_id(
})))
}
pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
pub(crate) fn router(api_allowed: Option<HeaderValue>, has_oidc: bool) -> Router<AppState> {
let origin: AllowOrigin = match api_allowed {
Some(n) => n.into(),
None => cors::Any.into(),
@ -185,7 +280,7 @@ pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
let mk_service = |m: Vec<Method>| cors_base.clone().allow_methods(m);
Router::new()
let router = Router::new()
.route(
"/search/user/:name",
get(get_user_id).layer(mk_service(vec![Method::GET])),
@ -246,5 +341,18 @@ pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
.delete(recipe::delete_ig)
.put(recipe::add_ig_request)
.layer(mk_service(vec![Method::PATCH, Method::DELETE, Method::PUT])),
);
if has_oidc {
async fn unit() {}
router
.route("/login/oidc", get(oidc_login))
.route(
"/login/has_oidc",
get(unit).layer(mk_service(vec![Method::GET])),
)
.route("/login/redirect/:id", get(oidc_login_finish))
} else {
router
}
}