Implement OIDC flow
This commit is contained in:
parent
c7e3468acc
commit
7bf18486bc
4 changed files with 1057 additions and 10 deletions
746
Cargo.lock
generated
746
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -5,14 +5,18 @@ authors = ["traxys <quentin@familleboyer.net>"]
|
|||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.6.20"
|
||||
axum = { version = "0.6.20", features = ["query"] }
|
||||
color-eyre = "0.6.2"
|
||||
envious = "0.2.2"
|
||||
openidconnect = "3.3.0"
|
||||
parking_lot = "0.12.1"
|
||||
serde = { version = "1.0.183", features = ["derive"] }
|
||||
sqlx = { version = "0.7.1", features = ["runtime-tokio", "postgres", "uuid", "migrate"] }
|
||||
thiserror = "1.0.44"
|
||||
tokio = { version = "1.31.0", features = ["full"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
|
||||
uuid = { version = "1.4.1", features = ["v4", "serde"] }
|
||||
|
||||
[profile.dev.package.sqlx-macros]
|
||||
opt-level = 3
|
||||
|
|
|
|||
2
migrations/20230814172950_openid.sql
Normal file
2
migrations/20230814172950_openid.sql
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
-- Add migration script here
|
||||
ALTER TABLE accounts ADD COLUMN sub text UNIQUE NOT NULL;
|
||||
313
src/main.rs
313
src/main.rs
|
|
@ -1,9 +1,27 @@
|
|||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use axum::Router;
|
||||
use serde::Deserialize;
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use color_eyre::eyre;
|
||||
use openidconnect::{
|
||||
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
|
||||
url::Url,
|
||||
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
|
||||
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse,
|
||||
};
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use sqlx::{postgres::PgPoolOptions, PgPool};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn default_port() -> u16 {
|
||||
8080
|
||||
|
|
@ -13,6 +31,33 @@ fn default_address() -> String {
|
|||
"127.0.0.1".into()
|
||||
}
|
||||
|
||||
fn deserialize_comma<'de, D>(de: D) -> Result<Vec<openidconnect::Scope>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Visitor;
|
||||
|
||||
struct CommaVisitor;
|
||||
impl<'de> Visitor<'de> for CommaVisitor {
|
||||
type Value = Vec<openidconnect::Scope>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("string containg comma separated strings")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(v.split(',')
|
||||
.map(|v| openidconnect::Scope::new(v.to_string()))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
de.deserialize_str(CommaVisitor)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
struct Settings {
|
||||
|
|
@ -21,6 +66,14 @@ struct Settings {
|
|||
#[serde(default = "default_address")]
|
||||
address: String,
|
||||
database_url: String,
|
||||
|
||||
domain: String,
|
||||
|
||||
oidc_endpoint: String,
|
||||
client_id: ClientId,
|
||||
client_secret: ClientSecret,
|
||||
#[serde(deserialize_with = "deserialize_comma")]
|
||||
scopes: Vec<openidconnect::Scope>,
|
||||
}
|
||||
|
||||
impl Settings {
|
||||
|
|
@ -33,8 +86,255 @@ impl Settings {
|
|||
}
|
||||
}
|
||||
|
||||
struct OpenidConnector {
|
||||
provider: CoreClient,
|
||||
scopes: Vec<openidconnect::Scope>,
|
||||
redirect_base: Url,
|
||||
inflight: parking_lot::Mutex<FifoMap>,
|
||||
}
|
||||
|
||||
struct OpenidAuthState {
|
||||
pkce_verifier: PkceCodeVerifier,
|
||||
csrf_token: CsrfToken,
|
||||
nonce: Nonce,
|
||||
}
|
||||
|
||||
struct FifoMap {
|
||||
max: usize,
|
||||
order: VecDeque<Uuid>,
|
||||
values: HashMap<Uuid, OpenidAuthState>,
|
||||
}
|
||||
|
||||
struct FifoMapInsert<'a> {
|
||||
pub id: Uuid,
|
||||
map: &'a mut FifoMap,
|
||||
}
|
||||
|
||||
impl<'a> FifoMapInsert<'a> {
|
||||
pub fn insert(self, state: OpenidAuthState) {
|
||||
let FifoMapInsert { id, map } = self;
|
||||
|
||||
map.order.push_back(id);
|
||||
if map.order.len() > map.max {
|
||||
let first = map.order.pop_front().unwrap();
|
||||
map.values.remove(&first);
|
||||
}
|
||||
|
||||
map.values.insert(id, state);
|
||||
}
|
||||
}
|
||||
|
||||
impl FifoMap {
|
||||
pub fn new(max: usize) -> Self {
|
||||
Self {
|
||||
max,
|
||||
order: VecDeque::new(),
|
||||
values: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_entry(&mut self) -> FifoMapInsert {
|
||||
let id = loop {
|
||||
let id = Uuid::new_v4();
|
||||
if !self.values.contains_key(&id) {
|
||||
break id;
|
||||
}
|
||||
};
|
||||
|
||||
FifoMapInsert { id, map: self }
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, id: Uuid) -> Option<OpenidAuthState> {
|
||||
let state = self.values.remove(&id);
|
||||
|
||||
if let Some(idx) = self.order.iter().position(|&v| v == id) {
|
||||
self.order.remove(idx);
|
||||
};
|
||||
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenidConnector {
|
||||
async fn new(settings: &Settings) -> color_eyre::Result<Self> {
|
||||
let metadata = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(settings.oidc_endpoint.clone())?,
|
||||
openidconnect::reqwest::async_http_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let provider = CoreClient::from_provider_metadata(
|
||||
metadata,
|
||||
settings.client_id.clone(),
|
||||
Some(settings.client_secret.clone()),
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
scopes: settings.scopes.clone(),
|
||||
redirect_base: (settings.domain.clone() + "/login/redirect/").parse()?,
|
||||
inflight: parking_lot::Mutex::new(FifoMap::new(1024)),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn start_auth(&self) -> Url {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
let mut inflight = self.inflight.lock();
|
||||
|
||||
let slot = inflight.new_entry();
|
||||
|
||||
tracing::info!("Login flow with id {}", slot.id);
|
||||
|
||||
let (url, csrf_token, nonce) = self
|
||||
.provider
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::AuthorizationCode,
|
||||
CsrfToken::new_random,
|
||||
Nonce::new_random,
|
||||
)
|
||||
.add_scopes(self.scopes.iter().cloned())
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
||||
self.redirect_base.join(&slot.id.to_string()).unwrap(),
|
||||
)))
|
||||
.url();
|
||||
|
||||
slot.insert(OpenidAuthState {
|
||||
pkce_verifier,
|
||||
csrf_token,
|
||||
nonce,
|
||||
});
|
||||
|
||||
url
|
||||
}
|
||||
|
||||
pub async fn redirected(
|
||||
&self,
|
||||
id: Uuid,
|
||||
csrf_state: String,
|
||||
code: String,
|
||||
) -> color_eyre::Result<String> {
|
||||
let Some(state) = self.inflight.lock().remove(id) else {
|
||||
eyre::bail!("Redirect has expired or has already been submitted")
|
||||
};
|
||||
|
||||
tracing::info!("Redirected to continue login flow for {id}");
|
||||
|
||||
if state.csrf_token.secret() != &csrf_state {
|
||||
eyre::bail!(
|
||||
"CRSF state does not match, expected: {} got {csrf_state}",
|
||||
state.csrf_token.secret()
|
||||
)
|
||||
}
|
||||
|
||||
let token_response = self
|
||||
.provider
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.set_pkce_verifier(state.pkce_verifier)
|
||||
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
||||
self.redirect_base.join(&id.to_string()).unwrap(),
|
||||
)))
|
||||
.request_async(openidconnect::reqwest::async_http_client)
|
||||
.await?;
|
||||
|
||||
let Some(id_token) = token_response.id_token() else {
|
||||
eyre::bail!("Server did not return an ID token")
|
||||
};
|
||||
|
||||
let claims = id_token.claims(&self.provider.id_token_verifier(), &state.nonce)?;
|
||||
|
||||
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||
let actual_access_token_hash = AccessTokenHash::from_token(
|
||||
token_response.access_token(),
|
||||
&id_token.signing_alg()?,
|
||||
)?;
|
||||
|
||||
if actual_access_token_hash != *expected_access_token_hash {
|
||||
eyre::bail!("Invalid access token")
|
||||
}
|
||||
}
|
||||
|
||||
Ok(claims.subject().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
struct AppState {
|
||||
db: PgPool,
|
||||
oidc: OpenidConnector,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
enum Error {
|
||||
#[error("An error occured in the database")]
|
||||
DbError(#[from] sqlx::Error),
|
||||
#[error("An internal error occured")]
|
||||
InternalError,
|
||||
}
|
||||
|
||||
struct InternalError;
|
||||
|
||||
impl IntoResponse for InternalError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal Error").into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Error::InternalError => InternalError.into_response(),
|
||||
Error::DbError(e) => {
|
||||
tracing::error!("Database error: {e:?}");
|
||||
InternalError.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn login(state: State<Arc<AppState>>) -> Redirect {
|
||||
let redirect = state.oidc.start_auth();
|
||||
|
||||
Redirect::to(redirect.as_str())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OidcRedirectParams {
|
||||
state: String,
|
||||
code: String,
|
||||
}
|
||||
|
||||
async fn redirected(
|
||||
state: State<Arc<AppState>>,
|
||||
Path(id): Path<Uuid>,
|
||||
Query(redirect): Query<OidcRedirectParams>,
|
||||
) -> Result<(), Error> {
|
||||
match state
|
||||
.oidc
|
||||
.redirected(id, redirect.state, redirect.code)
|
||||
.await
|
||||
{
|
||||
Ok(sub) => {
|
||||
let account = sqlx::query!("SELECT id FROM accounts WHERE sub = $1", sub)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
let _id = match account {
|
||||
Some(r) => r.id,
|
||||
None => {
|
||||
let id = Uuid::new_v4();
|
||||
sqlx::query!("INSERT INTO accounts (id, sub) VALUES ($1, $2)", id, sub)
|
||||
.execute(&state.db)
|
||||
.await?;
|
||||
id
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Could not finish OAuth2 flow: {e:?}");
|
||||
Err(Error::InternalError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
|
|
@ -50,6 +350,8 @@ async fn main() -> color_eyre::Result<()> {
|
|||
let config = Settings::new()?;
|
||||
tracing::info!("Settings: {config:#?}");
|
||||
|
||||
let oidc = OpenidConnector::new(&config).await?;
|
||||
|
||||
let addr: SocketAddr = format!("{}:{}", config.address, config.port).parse()?;
|
||||
|
||||
let db = PgPoolOptions::new()
|
||||
|
|
@ -61,7 +363,10 @@ async fn main() -> color_eyre::Result<()> {
|
|||
|
||||
tracing::info!("Listening on {addr}");
|
||||
|
||||
let router = Router::new().with_state(Arc::new(AppState { db }));
|
||||
let router = Router::new()
|
||||
.route("/login", get(login))
|
||||
.route("/login/redirect/:id", get(redirected))
|
||||
.with_state(Arc::new(AppState { db, oidc }));
|
||||
|
||||
Ok(axum::Server::bind(&addr)
|
||||
.serve(router.into_make_service())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue