Implement OIDC flow

This commit is contained in:
traxys 2023-08-15 15:57:19 +02:00
parent c7e3468acc
commit 7bf18486bc
4 changed files with 1057 additions and 10 deletions

746
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -5,14 +5,18 @@ authors = ["traxys <quentin@familleboyer.net>"]
edition = "2021"
[dependencies]
axum = "0.6.20"
axum = { version = "0.6.20", features = ["query"] }
color-eyre = "0.6.2"
envious = "0.2.2"
openidconnect = "3.3.0"
parking_lot = "0.12.1"
serde = { version = "1.0.183", features = ["derive"] }
sqlx = { version = "0.7.1", features = ["runtime-tokio", "postgres", "uuid", "migrate"] }
thiserror = "1.0.44"
tokio = { version = "1.31.0", features = ["full"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
uuid = { version = "1.4.1", features = ["v4", "serde"] }
[profile.dev.package.sqlx-macros]
opt-level = 3

View file

@ -0,0 +1,2 @@
-- Add migration script here
ALTER TABLE accounts ADD COLUMN sub text UNIQUE NOT NULL;

View file

@ -1,9 +1,27 @@
use std::{net::SocketAddr, sync::Arc};
use std::{
collections::{HashMap, VecDeque},
net::SocketAddr,
sync::Arc,
};
use axum::Router;
use serde::Deserialize;
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Redirect},
routing::get,
Router,
};
use color_eyre::eyre;
use openidconnect::{
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
url::Url,
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse,
};
use serde::{Deserialize, Deserializer};
use sqlx::{postgres::PgPoolOptions, PgPool};
use tracing_subscriber::EnvFilter;
use uuid::Uuid;
fn default_port() -> u16 {
8080
@ -13,6 +31,33 @@ fn default_address() -> String {
"127.0.0.1".into()
}
fn deserialize_comma<'de, D>(de: D) -> Result<Vec<openidconnect::Scope>, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Visitor;
struct CommaVisitor;
impl<'de> Visitor<'de> for CommaVisitor {
type Value = Vec<openidconnect::Scope>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("string containg comma separated strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.split(',')
.map(|v| openidconnect::Scope::new(v.to_string()))
.collect())
}
}
de.deserialize_str(CommaVisitor)
}
#[derive(Deserialize, Debug)]
#[serde(rename_all = "UPPERCASE")]
struct Settings {
@ -21,6 +66,14 @@ struct Settings {
#[serde(default = "default_address")]
address: String,
database_url: String,
domain: String,
oidc_endpoint: String,
client_id: ClientId,
client_secret: ClientSecret,
#[serde(deserialize_with = "deserialize_comma")]
scopes: Vec<openidconnect::Scope>,
}
impl Settings {
@ -33,8 +86,255 @@ impl Settings {
}
}
struct OpenidConnector {
provider: CoreClient,
scopes: Vec<openidconnect::Scope>,
redirect_base: Url,
inflight: parking_lot::Mutex<FifoMap>,
}
struct OpenidAuthState {
pkce_verifier: PkceCodeVerifier,
csrf_token: CsrfToken,
nonce: Nonce,
}
struct FifoMap {
max: usize,
order: VecDeque<Uuid>,
values: HashMap<Uuid, OpenidAuthState>,
}
struct FifoMapInsert<'a> {
pub id: Uuid,
map: &'a mut FifoMap,
}
impl<'a> FifoMapInsert<'a> {
pub fn insert(self, state: OpenidAuthState) {
let FifoMapInsert { id, map } = self;
map.order.push_back(id);
if map.order.len() > map.max {
let first = map.order.pop_front().unwrap();
map.values.remove(&first);
}
map.values.insert(id, state);
}
}
impl FifoMap {
pub fn new(max: usize) -> Self {
Self {
max,
order: VecDeque::new(),
values: HashMap::new(),
}
}
pub fn new_entry(&mut self) -> FifoMapInsert {
let id = loop {
let id = Uuid::new_v4();
if !self.values.contains_key(&id) {
break id;
}
};
FifoMapInsert { id, map: self }
}
pub fn remove(&mut self, id: Uuid) -> Option<OpenidAuthState> {
let state = self.values.remove(&id);
if let Some(idx) = self.order.iter().position(|&v| v == id) {
self.order.remove(idx);
};
state
}
}
impl OpenidConnector {
async fn new(settings: &Settings) -> color_eyre::Result<Self> {
let metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(settings.oidc_endpoint.clone())?,
openidconnect::reqwest::async_http_client,
)
.await?;
let provider = CoreClient::from_provider_metadata(
metadata,
settings.client_id.clone(),
Some(settings.client_secret.clone()),
);
Ok(Self {
provider,
scopes: settings.scopes.clone(),
redirect_base: (settings.domain.clone() + "/login/redirect/").parse()?,
inflight: parking_lot::Mutex::new(FifoMap::new(1024)),
})
}
pub fn start_auth(&self) -> Url {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut inflight = self.inflight.lock();
let slot = inflight.new_entry();
tracing::info!("Login flow with id {}", slot.id);
let (url, csrf_token, nonce) = self
.provider
.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
)
.add_scopes(self.scopes.iter().cloned())
.set_pkce_challenge(pkce_challenge)
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
self.redirect_base.join(&slot.id.to_string()).unwrap(),
)))
.url();
slot.insert(OpenidAuthState {
pkce_verifier,
csrf_token,
nonce,
});
url
}
pub async fn redirected(
&self,
id: Uuid,
csrf_state: String,
code: String,
) -> color_eyre::Result<String> {
let Some(state) = self.inflight.lock().remove(id) else {
eyre::bail!("Redirect has expired or has already been submitted")
};
tracing::info!("Redirected to continue login flow for {id}");
if state.csrf_token.secret() != &csrf_state {
eyre::bail!(
"CRSF state does not match, expected: {} got {csrf_state}",
state.csrf_token.secret()
)
}
let token_response = self
.provider
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(state.pkce_verifier)
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
self.redirect_base.join(&id.to_string()).unwrap(),
)))
.request_async(openidconnect::reqwest::async_http_client)
.await?;
let Some(id_token) = token_response.id_token() else {
eyre::bail!("Server did not return an ID token")
};
let claims = id_token.claims(&self.provider.id_token_verifier(), &state.nonce)?;
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(
token_response.access_token(),
&id_token.signing_alg()?,
)?;
if actual_access_token_hash != *expected_access_token_hash {
eyre::bail!("Invalid access token")
}
}
Ok(claims.subject().to_string())
}
}
struct AppState {
db: PgPool,
oidc: OpenidConnector,
}
#[derive(thiserror::Error, Debug)]
enum Error {
#[error("An error occured in the database")]
DbError(#[from] sqlx::Error),
#[error("An internal error occured")]
InternalError,
}
struct InternalError;
impl IntoResponse for InternalError {
fn into_response(self) -> axum::response::Response {
(StatusCode::INTERNAL_SERVER_ERROR, "Internal Error").into_response()
}
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
match self {
Error::InternalError => InternalError.into_response(),
Error::DbError(e) => {
tracing::error!("Database error: {e:?}");
InternalError.into_response()
}
}
}
}
async fn login(state: State<Arc<AppState>>) -> Redirect {
let redirect = state.oidc.start_auth();
Redirect::to(redirect.as_str())
}
#[derive(Deserialize)]
struct OidcRedirectParams {
state: String,
code: String,
}
async fn redirected(
state: State<Arc<AppState>>,
Path(id): Path<Uuid>,
Query(redirect): Query<OidcRedirectParams>,
) -> Result<(), Error> {
match state
.oidc
.redirected(id, redirect.state, redirect.code)
.await
{
Ok(sub) => {
let account = sqlx::query!("SELECT id FROM accounts WHERE sub = $1", sub)
.fetch_optional(&state.db)
.await?;
let _id = match account {
Some(r) => r.id,
None => {
let id = Uuid::new_v4();
sqlx::query!("INSERT INTO accounts (id, sub) VALUES ($1, $2)", id, sub)
.execute(&state.db)
.await?;
id
}
};
Ok(())
}
Err(e) => {
tracing::error!("Could not finish OAuth2 flow: {e:?}");
Err(Error::InternalError)
}
}
}
#[tokio::main]
@ -50,6 +350,8 @@ async fn main() -> color_eyre::Result<()> {
let config = Settings::new()?;
tracing::info!("Settings: {config:#?}");
let oidc = OpenidConnector::new(&config).await?;
let addr: SocketAddr = format!("{}:{}", config.address, config.port).parse()?;
let db = PgPoolOptions::new()
@ -61,7 +363,10 @@ async fn main() -> color_eyre::Result<()> {
tracing::info!("Listening on {addr}");
let router = Router::new().with_state(Arc::new(AppState { db }));
let router = Router::new()
.route("/login", get(login))
.route("/login/redirect/:id", get(redirected))
.with_state(Arc::new(AppState { db, oidc }));
Ok(axum::Server::bind(&addr)
.serve(router.into_make_service())