2023-08-15 15:57:19 +02:00
|
|
|
use std::{
|
|
|
|
|
collections::{HashMap, VecDeque},
|
|
|
|
|
net::SocketAddr,
|
|
|
|
|
sync::Arc,
|
|
|
|
|
};
|
2023-08-14 17:59:37 +02:00
|
|
|
|
2023-08-15 15:57:19 +02:00
|
|
|
use axum::{
|
|
|
|
|
extract::{Path, Query, State},
|
|
|
|
|
http::StatusCode,
|
2023-08-15 19:40:11 +02:00
|
|
|
response::{Html, IntoResponse, Redirect},
|
2023-08-15 15:57:19 +02:00
|
|
|
routing::get,
|
|
|
|
|
Router,
|
|
|
|
|
};
|
|
|
|
|
use color_eyre::eyre;
|
2023-08-15 19:46:26 +02:00
|
|
|
use once_cell::sync::Lazy;
|
2023-08-15 15:57:19 +02:00
|
|
|
use openidconnect::{
|
|
|
|
|
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
|
|
|
|
|
url::Url,
|
|
|
|
|
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
|
|
|
|
|
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse,
|
|
|
|
|
};
|
|
|
|
|
use serde::{Deserialize, Deserializer};
|
2023-08-14 18:26:38 +02:00
|
|
|
use sqlx::{postgres::PgPoolOptions, PgPool};
|
2023-08-15 19:40:11 +02:00
|
|
|
use tera::Tera;
|
2023-08-14 17:59:37 +02:00
|
|
|
use tracing_subscriber::EnvFilter;
|
2023-08-15 15:57:19 +02:00
|
|
|
use uuid::Uuid;
|
2023-08-14 17:59:37 +02:00
|
|
|
|
|
|
|
|
/// Port the HTTP server listens on when `MAIL_ADMIN_PORT` is not set.
fn default_port() -> u16 {
    const FALLBACK_PORT: u16 = 8080;
    FALLBACK_PORT
}
|
|
|
|
|
|
|
|
|
|
/// Address the HTTP server binds to when `MAIL_ADMIN_ADDRESS` is not set (loopback only).
fn default_address() -> String {
    String::from("127.0.0.1")
}
|
|
|
|
|
|
2023-08-15 15:57:19 +02:00
|
|
|
fn deserialize_comma<'de, D>(de: D) -> Result<Vec<openidconnect::Scope>, D::Error>
|
|
|
|
|
where
|
|
|
|
|
D: Deserializer<'de>,
|
|
|
|
|
{
|
|
|
|
|
use serde::de::Visitor;
|
|
|
|
|
|
|
|
|
|
struct CommaVisitor;
|
|
|
|
|
impl<'de> Visitor<'de> for CommaVisitor {
|
|
|
|
|
type Value = Vec<openidconnect::Scope>;
|
|
|
|
|
|
|
|
|
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
|
|
|
formatter.write_str("string containg comma separated strings")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
|
|
|
|
where
|
|
|
|
|
E: serde::de::Error,
|
|
|
|
|
{
|
|
|
|
|
Ok(v.split(',')
|
|
|
|
|
.map(|v| openidconnect::Scope::new(v.to_string()))
|
|
|
|
|
.collect())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
de.deserialize_str(CommaVisitor)
|
|
|
|
|
}
|
|
|
|
|
|
2023-08-14 17:59:37 +02:00
|
|
|
#[derive(Deserialize, Debug)]
|
|
|
|
|
#[serde(rename_all = "UPPERCASE")]
|
|
|
|
|
struct Settings {
|
|
|
|
|
#[serde(default = "default_port")]
|
|
|
|
|
port: u16,
|
|
|
|
|
#[serde(default = "default_address")]
|
|
|
|
|
address: String,
|
2023-08-14 18:26:38 +02:00
|
|
|
database_url: String,
|
2023-08-15 15:57:19 +02:00
|
|
|
|
|
|
|
|
domain: String,
|
|
|
|
|
|
|
|
|
|
oidc_endpoint: String,
|
|
|
|
|
client_id: ClientId,
|
|
|
|
|
client_secret: ClientSecret,
|
|
|
|
|
#[serde(deserialize_with = "deserialize_comma")]
|
|
|
|
|
scopes: Vec<openidconnect::Scope>,
|
2023-08-14 17:59:37 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Settings {
|
|
|
|
|
fn new() -> color_eyre::Result<Self> {
|
|
|
|
|
envious::Config::default()
|
|
|
|
|
.with_prefix("MAIL_ADMIN_")
|
|
|
|
|
.case_sensitive(true)
|
|
|
|
|
.build_from_env()
|
|
|
|
|
.map_err(Into::into)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-08-15 15:57:19 +02:00
|
|
|
struct OpenidConnector {
|
|
|
|
|
provider: CoreClient,
|
|
|
|
|
scopes: Vec<openidconnect::Scope>,
|
|
|
|
|
redirect_base: Url,
|
|
|
|
|
inflight: parking_lot::Mutex<FifoMap>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct OpenidAuthState {
|
|
|
|
|
pkce_verifier: PkceCodeVerifier,
|
|
|
|
|
csrf_token: CsrfToken,
|
|
|
|
|
nonce: Nonce,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct FifoMap {
|
|
|
|
|
max: usize,
|
|
|
|
|
order: VecDeque<Uuid>,
|
|
|
|
|
values: HashMap<Uuid, OpenidAuthState>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct FifoMapInsert<'a> {
|
|
|
|
|
pub id: Uuid,
|
|
|
|
|
map: &'a mut FifoMap,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<'a> FifoMapInsert<'a> {
|
|
|
|
|
pub fn insert(self, state: OpenidAuthState) {
|
|
|
|
|
let FifoMapInsert { id, map } = self;
|
|
|
|
|
|
|
|
|
|
map.order.push_back(id);
|
|
|
|
|
if map.order.len() > map.max {
|
|
|
|
|
let first = map.order.pop_front().unwrap();
|
|
|
|
|
map.values.remove(&first);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
map.values.insert(id, state);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl FifoMap {
|
|
|
|
|
pub fn new(max: usize) -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
max,
|
|
|
|
|
order: VecDeque::new(),
|
|
|
|
|
values: HashMap::new(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn new_entry(&mut self) -> FifoMapInsert {
|
|
|
|
|
let id = loop {
|
|
|
|
|
let id = Uuid::new_v4();
|
|
|
|
|
if !self.values.contains_key(&id) {
|
|
|
|
|
break id;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
FifoMapInsert { id, map: self }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn remove(&mut self, id: Uuid) -> Option<OpenidAuthState> {
|
|
|
|
|
let state = self.values.remove(&id);
|
|
|
|
|
|
|
|
|
|
if let Some(idx) = self.order.iter().position(|&v| v == id) {
|
|
|
|
|
self.order.remove(idx);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
state
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl OpenidConnector {
|
|
|
|
|
async fn new(settings: &Settings) -> color_eyre::Result<Self> {
|
|
|
|
|
let metadata = CoreProviderMetadata::discover_async(
|
|
|
|
|
IssuerUrl::new(settings.oidc_endpoint.clone())?,
|
|
|
|
|
openidconnect::reqwest::async_http_client,
|
|
|
|
|
)
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
let provider = CoreClient::from_provider_metadata(
|
|
|
|
|
metadata,
|
|
|
|
|
settings.client_id.clone(),
|
|
|
|
|
Some(settings.client_secret.clone()),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
Ok(Self {
|
|
|
|
|
provider,
|
|
|
|
|
scopes: settings.scopes.clone(),
|
|
|
|
|
redirect_base: (settings.domain.clone() + "/login/redirect/").parse()?,
|
|
|
|
|
inflight: parking_lot::Mutex::new(FifoMap::new(1024)),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn start_auth(&self) -> Url {
|
|
|
|
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
|
|
|
|
|
|
|
|
|
let mut inflight = self.inflight.lock();
|
|
|
|
|
|
|
|
|
|
let slot = inflight.new_entry();
|
|
|
|
|
|
|
|
|
|
tracing::info!("Login flow with id {}", slot.id);
|
|
|
|
|
|
|
|
|
|
let (url, csrf_token, nonce) = self
|
|
|
|
|
.provider
|
|
|
|
|
.authorize_url(
|
|
|
|
|
CoreAuthenticationFlow::AuthorizationCode,
|
|
|
|
|
CsrfToken::new_random,
|
|
|
|
|
Nonce::new_random,
|
|
|
|
|
)
|
|
|
|
|
.add_scopes(self.scopes.iter().cloned())
|
|
|
|
|
.set_pkce_challenge(pkce_challenge)
|
|
|
|
|
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
|
|
|
|
self.redirect_base.join(&slot.id.to_string()).unwrap(),
|
|
|
|
|
)))
|
|
|
|
|
.url();
|
|
|
|
|
|
|
|
|
|
slot.insert(OpenidAuthState {
|
|
|
|
|
pkce_verifier,
|
|
|
|
|
csrf_token,
|
|
|
|
|
nonce,
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
url
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub async fn redirected(
|
|
|
|
|
&self,
|
|
|
|
|
id: Uuid,
|
|
|
|
|
csrf_state: String,
|
|
|
|
|
code: String,
|
|
|
|
|
) -> color_eyre::Result<String> {
|
|
|
|
|
let Some(state) = self.inflight.lock().remove(id) else {
|
|
|
|
|
eyre::bail!("Redirect has expired or has already been submitted")
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
tracing::info!("Redirected to continue login flow for {id}");
|
|
|
|
|
|
|
|
|
|
if state.csrf_token.secret() != &csrf_state {
|
|
|
|
|
eyre::bail!(
|
|
|
|
|
"CRSF state does not match, expected: {} got {csrf_state}",
|
|
|
|
|
state.csrf_token.secret()
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let token_response = self
|
|
|
|
|
.provider
|
|
|
|
|
.exchange_code(AuthorizationCode::new(code))
|
|
|
|
|
.set_pkce_verifier(state.pkce_verifier)
|
|
|
|
|
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
|
|
|
|
self.redirect_base.join(&id.to_string()).unwrap(),
|
|
|
|
|
)))
|
|
|
|
|
.request_async(openidconnect::reqwest::async_http_client)
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
let Some(id_token) = token_response.id_token() else {
|
|
|
|
|
eyre::bail!("Server did not return an ID token")
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let claims = id_token.claims(&self.provider.id_token_verifier(), &state.nonce)?;
|
|
|
|
|
|
|
|
|
|
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
|
|
|
|
let actual_access_token_hash = AccessTokenHash::from_token(
|
|
|
|
|
token_response.access_token(),
|
|
|
|
|
&id_token.signing_alg()?,
|
|
|
|
|
)?;
|
|
|
|
|
|
|
|
|
|
if actual_access_token_hash != *expected_access_token_hash {
|
|
|
|
|
eyre::bail!("Invalid access token")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(claims.subject().to_string())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-08-14 18:26:38 +02:00
|
|
|
struct AppState {
|
|
|
|
|
db: PgPool,
|
2023-08-15 15:57:19 +02:00
|
|
|
oidc: OpenidConnector,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(thiserror::Error, Debug)]
|
|
|
|
|
enum Error {
|
|
|
|
|
#[error("An error occured in the database")]
|
2023-08-15 19:40:11 +02:00
|
|
|
Db(#[from] sqlx::Error),
|
|
|
|
|
#[error("An error occured when rendering a template")]
|
|
|
|
|
Tera(#[from] tera::Error),
|
2023-08-15 15:57:19 +02:00
|
|
|
#[error("An internal error occured")]
|
|
|
|
|
InternalError,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct InternalError;
|
|
|
|
|
|
|
|
|
|
impl IntoResponse for InternalError {
|
|
|
|
|
fn into_response(self) -> axum::response::Response {
|
2023-08-15 19:46:26 +02:00
|
|
|
(
|
|
|
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
|
|
|
match TEMPLATES.render("error.html", &global_context()) {
|
|
|
|
|
Ok(v) => Html(v).into_response(),
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!("Could not generate internal error: {e:?}");
|
|
|
|
|
"Internal Error".into_response()
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
.into_response()
|
2023-08-15 15:57:19 +02:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl IntoResponse for Error {
|
|
|
|
|
fn into_response(self) -> axum::response::Response {
|
|
|
|
|
match self {
|
|
|
|
|
Error::InternalError => InternalError.into_response(),
|
2023-08-15 19:40:11 +02:00
|
|
|
Error::Db(e) => {
|
2023-08-15 15:57:19 +02:00
|
|
|
tracing::error!("Database error: {e:?}");
|
|
|
|
|
InternalError.into_response()
|
|
|
|
|
}
|
2023-08-15 19:40:11 +02:00
|
|
|
Error::Tera(e) => {
|
|
|
|
|
tracing::error!("Tera error: {e:?}");
|
|
|
|
|
InternalError.into_response()
|
|
|
|
|
}
|
2023-08-15 15:57:19 +02:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn login(state: State<Arc<AppState>>) -> Redirect {
|
|
|
|
|
let redirect = state.oidc.start_auth();
|
|
|
|
|
|
|
|
|
|
Redirect::to(redirect.as_str())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
|
|
|
struct OidcRedirectParams {
|
|
|
|
|
state: String,
|
|
|
|
|
code: String,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn redirected(
|
|
|
|
|
state: State<Arc<AppState>>,
|
|
|
|
|
Path(id): Path<Uuid>,
|
|
|
|
|
Query(redirect): Query<OidcRedirectParams>,
|
|
|
|
|
) -> Result<(), Error> {
|
|
|
|
|
match state
|
|
|
|
|
.oidc
|
|
|
|
|
.redirected(id, redirect.state, redirect.code)
|
|
|
|
|
.await
|
|
|
|
|
{
|
|
|
|
|
Ok(sub) => {
|
|
|
|
|
let account = sqlx::query!("SELECT id FROM accounts WHERE sub = $1", sub)
|
|
|
|
|
.fetch_optional(&state.db)
|
|
|
|
|
.await?;
|
|
|
|
|
let _id = match account {
|
|
|
|
|
Some(r) => r.id,
|
|
|
|
|
None => {
|
|
|
|
|
let id = Uuid::new_v4();
|
|
|
|
|
sqlx::query!("INSERT INTO accounts (id, sub) VALUES ($1, $2)", id, sub)
|
|
|
|
|
.execute(&state.db)
|
|
|
|
|
.await?;
|
|
|
|
|
id
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
Err(e) => {
|
|
|
|
|
tracing::error!("Could not finish OAuth2 flow: {e:?}");
|
|
|
|
|
Err(Error::InternalError)
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-08-14 18:26:38 +02:00
|
|
|
}
|
|
|
|
|
|
2023-08-15 19:40:11 +02:00
|
|
|
/// Builds the Tera context shared by every page template (currently only the page title).
fn global_context() -> tera::Context {
    let mut shared = tera::Context::new();
    shared.insert("title", "Mail accounts");
    shared
}
|
|
|
|
|
|
2023-08-15 19:46:26 +02:00
|
|
|
async fn page_not_found() -> Result<(StatusCode, Html<String>), Error> {
|
2023-08-15 19:40:11 +02:00
|
|
|
Ok((
|
|
|
|
|
StatusCode::NOT_FOUND,
|
2023-08-15 19:46:26 +02:00
|
|
|
Html(TEMPLATES.render("not_found.html", &global_context())?),
|
2023-08-15 19:40:11 +02:00
|
|
|
))
|
|
|
|
|
}
|
|
|
|
|
|
2023-08-15 19:46:26 +02:00
|
|
|
pub static TEMPLATES: Lazy<Tera> =
|
|
|
|
|
Lazy::new(|| Tera::new("templates/*.html").expect("Could not generate templates"));
|
|
|
|
|
|
2023-08-14 17:59:37 +02:00
|
|
|
#[tokio::main]
|
|
|
|
|
async fn main() -> color_eyre::Result<()> {
|
|
|
|
|
color_eyre::install()?;
|
|
|
|
|
|
|
|
|
|
tracing_subscriber::fmt()
|
|
|
|
|
.with_max_level(tracing::Level::DEBUG)
|
|
|
|
|
.with_target(true)
|
|
|
|
|
.with_env_filter(EnvFilter::from_default_env())
|
|
|
|
|
.init();
|
|
|
|
|
|
|
|
|
|
let config = Settings::new()?;
|
|
|
|
|
tracing::info!("Settings: {config:#?}");
|
|
|
|
|
|
2023-08-15 15:57:19 +02:00
|
|
|
let oidc = OpenidConnector::new(&config).await?;
|
|
|
|
|
|
2023-08-14 17:59:37 +02:00
|
|
|
let addr: SocketAddr = format!("{}:{}", config.address, config.port).parse()?;
|
|
|
|
|
|
2023-08-14 18:26:38 +02:00
|
|
|
let db = PgPoolOptions::new()
|
|
|
|
|
.max_connections(5)
|
|
|
|
|
.connect(&config.database_url)
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
sqlx::migrate!().run(&db).await?;
|
|
|
|
|
|
2023-08-14 17:59:37 +02:00
|
|
|
tracing::info!("Listening on {addr}");
|
|
|
|
|
|
2023-08-15 15:57:19 +02:00
|
|
|
let router = Router::new()
|
|
|
|
|
.route("/login", get(login))
|
|
|
|
|
.route("/login/redirect/:id", get(redirected))
|
2023-08-15 19:40:11 +02:00
|
|
|
.fallback(page_not_found)
|
2023-08-15 19:46:26 +02:00
|
|
|
.with_state(Arc::new(AppState { db, oidc }));
|
2023-08-14 17:59:37 +02:00
|
|
|
|
|
|
|
|
Ok(axum::Server::bind(&addr)
|
|
|
|
|
.serve(router.into_make_service())
|
|
|
|
|
.await?)
|
|
|
|
|
}
|