use std::{ collections::{HashMap, VecDeque}, fmt::Display, net::SocketAddr, sync::Arc, }; use argon2::{ password_hash::{rand_core::OsRng, SaltString}, Argon2, PasswordHasher, }; use axum::{ async_trait, extract::{FromRef, FromRequestParts, Path, Query, State}, http::{request::Parts, StatusCode}, response::{Html, IntoResponse, Redirect}, routing::{get, post}, Form, Router, }; use axum_extra::extract::{cookie::Cookie, CookieJar}; use base64::{engine::general_purpose, engine::Engine}; use color_eyre::eyre; use cookie::{time::OffsetDateTime, SameSite}; use futures_util::TryStreamExt; use itertools::Itertools; use jwt_simple::prelude::*; use once_cell::sync::Lazy; use openidconnect::{ core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata}, url::Url, AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse, }; use secrecy::{ExposeSecret, SecretString}; use serde::{Deserialize, Deserializer}; use sqlx::{postgres::PgPoolOptions, PgPool}; use tera::Tera; use tracing_subscriber::EnvFilter; use uuid::Uuid; fn default_port() -> u16 { 8080 } fn default_address() -> String { "127.0.0.1".into() } #[derive(Clone)] pub(crate) struct Base64(pub(crate) HS256Key); impl std::fmt::Debug for Base64 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, r#"b64"{}""#, &general_purpose::STANDARD.encode(self.0.to_bytes()) ) } } impl<'de> Deserialize<'de> for Base64 { fn deserialize(de: D) -> Result where D: Deserializer<'de>, { use serde::de::Visitor; struct DecodingVisitor; impl<'de> Visitor<'de> for DecodingVisitor { type Value = Base64; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("must be a base 64 string") } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { general_purpose::STANDARD .decode(v) .map_err(E::custom) .map(|b| HS256Key::from_bytes(&b)) .map(Base64) } } de.deserialize_str(DecodingVisitor) } } fn deserialize_comma<'de, D>(de: D) -> Result, D::Error> where D: Deserializer<'de>, { use serde::de::Visitor; struct CommaVisitor; impl<'de> Visitor<'de> for CommaVisitor { type Value = Vec; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("string containg comma separated strings") } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { Ok(v.split(',') .map(|v| openidconnect::Scope::new(v.to_string())) .collect()) } } de.deserialize_str(CommaVisitor) } #[derive(Deserialize, Debug)] #[serde(rename_all = "UPPERCASE")] struct Settings { jwt_secret: Base64, #[serde(default = "default_port")] port: u16, #[serde(default = "default_address")] address: String, database_url: String, mail_domain: String, domain: String, oidc_endpoint: String, client_id: ClientId, client_secret: ClientSecret, #[serde(deserialize_with = "deserialize_comma")] scopes: Vec, } impl Settings { fn new() -> color_eyre::Result { envious::Config::default() .with_prefix("MAIL_ADMIN_") .case_sensitive(true) .build_from_env() .map_err(Into::into) } } struct OpenidConnector { provider: CoreClient, scopes: Vec, redirect_base: Url, inflight: parking_lot::Mutex, } struct OpenidAuthState { pkce_verifier: PkceCodeVerifier, csrf_token: CsrfToken, nonce: Nonce, } struct FifoMap { max: usize, order: VecDeque, values: HashMap, } struct FifoMapInsert<'a> { pub id: Uuid, map: &'a mut FifoMap, } impl<'a> FifoMapInsert<'a> { pub fn insert(self, state: OpenidAuthState) { let FifoMapInsert { id, map } = self; map.order.push_back(id); if map.order.len() > map.max { let first = map.order.pop_front().unwrap(); map.values.remove(&first); } map.values.insert(id, state); } } impl FifoMap { pub fn new(max: usize) -> Self { Self { max, order: VecDeque::new(), values: HashMap::new(), } } pub fn new_entry(&mut self) -> FifoMapInsert { let id = loop { let id = Uuid::new_v4(); if !self.values.contains_key(&id) { break id; } }; FifoMapInsert { id, map: self } } pub fn remove(&mut self, id: Uuid) -> Option { let state = self.values.remove(&id); if let Some(idx) = self.order.iter().position(|&v| v == id) { self.order.remove(idx); }; state } } impl OpenidConnector { async fn new(settings: &Settings) -> color_eyre::Result { let metadata = CoreProviderMetadata::discover_async( IssuerUrl::new(settings.oidc_endpoint.clone())?, openidconnect::reqwest::async_http_client, ) .await?; let provider = CoreClient::from_provider_metadata( metadata, settings.client_id.clone(), Some(settings.client_secret.clone()), ); Ok(Self { provider, scopes: settings.scopes.clone(), redirect_base: (settings.domain.clone() + "/login/redirect/").parse()?, inflight: parking_lot::Mutex::new(FifoMap::new(1024)), }) } pub fn start_auth(&self) -> Url { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let mut inflight = self.inflight.lock(); let slot = inflight.new_entry(); tracing::info!("Login flow with id {}", slot.id); let (url, csrf_token, nonce) = self .provider .authorize_url( CoreAuthenticationFlow::AuthorizationCode, CsrfToken::new_random, Nonce::new_random, ) .add_scopes(self.scopes.iter().cloned()) .set_pkce_challenge(pkce_challenge) .set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url( self.redirect_base.join(&slot.id.to_string()).unwrap(), ))) .url(); slot.insert(OpenidAuthState { pkce_verifier, csrf_token, nonce, }); url } pub async fn redirected( &self, id: Uuid, csrf_state: String, code: String, ) -> color_eyre::Result { let Some(state) = self.inflight.lock().remove(id) else { eyre::bail!("Redirect has expired or has already been submitted") }; tracing::info!("Redirected to continue login flow for {id}"); if state.csrf_token.secret() != &csrf_state { eyre::bail!( "CRSF state does not match, expected: {} got {csrf_state}", state.csrf_token.secret() ) } let token_response = self .provider .exchange_code(AuthorizationCode::new(code)) .set_pkce_verifier(state.pkce_verifier) .set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url( self.redirect_base.join(&id.to_string()).unwrap(), ))) .request_async(openidconnect::reqwest::async_http_client) .await?; let Some(id_token) = token_response.id_token() else { eyre::bail!("Server did not return an ID token") }; let claims = id_token.claims(&self.provider.id_token_verifier(), &state.nonce)?; if let Some(expected_access_token_hash) = claims.access_token_hash() { let actual_access_token_hash = AccessTokenHash::from_token( token_response.access_token(), &id_token.signing_alg()?, )?; if actual_access_token_hash != *expected_access_token_hash { eyre::bail!("Invalid access token") } } Ok(claims.subject().to_string()) } } struct AppState { jwt_secret: HS256Key, db: PgPool, mail_domain: String, oidc: OpenidConnector, } #[derive(thiserror::Error, Debug)] enum Error { #[error("An error occured in the database")] Db(#[from] sqlx::Error), #[error("An error occured when rendering a template")] Tera(#[from] tera::Error), #[error("A JWT error occured")] Jwt(#[from] jwt_simple::Error), #[error("An argon2 error occured")] Argon2(#[from] argon2::password_hash::Error), #[error("An internal error occured")] InternalError, } struct InternalError; impl IntoResponse for InternalError { fn into_response(self) -> axum::response::Response { ( StatusCode::INTERNAL_SERVER_ERROR, match TEMPLATES.render("error.html", &global_context()) { Ok(v) => Html(v).into_response(), Err(e) => { tracing::error!("Could not generate internal error: {e:?}"); "Internal Error".into_response() } }, ) .into_response() } } impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { match self { Error::InternalError => InternalError.into_response(), Error::Db(e) => { tracing::error!("Database error: {e:?}"); InternalError.into_response() } Error::Tera(e) => { tracing::error!("Tera error: {e:?}"); InternalError.into_response() } Error::Jwt(e) => { tracing::error!("JWT error: {e:?}"); InternalError.into_response() } Error::Argon2(e) => { tracing::error!("Argon2 error: {e:?}"); InternalError.into_response() } } } } async fn login(state: State>) -> Redirect { let redirect = state.oidc.start_auth(); Redirect::to(redirect.as_str()) } #[derive(Deserialize)] struct OidcRedirectParams { state: String, code: String, } pub static AUTH_COOKIE: &str = "mail_admin_token"; async fn redirected( state: State>, Path(id): Path, Query(redirect): Query, jar: CookieJar, ) -> Result<(CookieJar, Redirect), Error> { match state .oidc .redirected(id, redirect.state, redirect.code) .await { Ok(sub) => { let account = sqlx::query!("SELECT id FROM accounts WHERE sub = $1", sub) .fetch_optional(&state.db) .await?; let id = match account { Some(r) => r.id, None => { let id = Uuid::new_v4(); sqlx::query!("INSERT INTO accounts (id, sub) VALUES ($1, $2)", id, sub) .execute(&state.db) .await?; id } }; let expire = std::time::Duration::from_secs(3600 * 24 * 31 * 6); let mut claims = Claims::create(expire.into()); claims.subject = Some(id.to_string()); let token = state.jwt_secret.authenticate(claims)?; let mut cookie = Cookie::named(AUTH_COOKIE); cookie.set_value(token); cookie.set_http_only(true); let mut now = OffsetDateTime::now_utc(); now += expire; cookie.set_expires(now); cookie.set_same_site(SameSite::None); cookie.set_secure(true); cookie.set_path("/"); let jar = jar.add(cookie); Ok((jar, Redirect::to("/"))) } Err(e) => { tracing::error!("Could not finish OAuth2 flow: {e:?}"); Err(Error::InternalError) } } } fn global_context() -> tera::Context { let mut ctx = tera::Context::new(); ctx.insert("title", "Mail accounts"); ctx } async fn page_not_found() -> Result<(StatusCode, Html), Error> { Ok(( StatusCode::NOT_FOUND, Html(TEMPLATES.render("not_found.html", &global_context())?), )) } pub static TEMPLATES: Lazy = Lazy::new(|| Tera::new("templates/*.html").expect("Could not generate templates")); struct User(Uuid); #[async_trait] impl FromRequestParts for User where S: Send + Sync, Arc: FromRef, { type Rejection = Redirect; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let jar = CookieJar::from_request_parts(parts, state).await.unwrap(); match jar.get(AUTH_COOKIE) { None => { tracing::debug!("No auth token"); Err(Redirect::to("/login")) } Some(data) => { let state: Arc = State::from_request_parts(parts, state).await.unwrap().0; let claims = state .jwt_secret .verify_token::(data.value(), None) .map_err(|_| Redirect::to("/login"))?; Ok(User(claims.subject.unwrap().parse().unwrap())) } } } } #[derive(Serialize, Deserialize, Debug)] struct Mail { mail: String, } #[derive(Serialize, Deserialize)] struct HomeQuery { user_error: Option, } #[derive(Deserialize, Debug)] struct AliasRecipient { alias: String, recipient: String, } #[derive(Serialize)] struct Alias { mail: String, recipients: Vec, } async fn home( state: State>, User(user): User, Query(query): Query, ) -> Result, Error> { let mails = sqlx::query_as!( Mail, r#" SELECT mail FROM emails WHERE id = $1 AND alias = false ORDER BY lower(substring(mail from position('@' in mail)+1 )),lower(mail) "#, user ) .fetch_all(&state.db) .await?; let aliases = sqlx::query_as!( Mail, r#" SELECT mail FROM emails WHERE id = $1 AND alias = true ORDER BY lower(substring(mail from position('@' in mail)+1 )),lower(mail) "#, user ) .fetch_all(&state.db) .await?; let mut alias_stream = sqlx::query_as!( AliasRecipient, r#" SELECT alias_recipient.mail as alias,recipient FROM emails,alias_recipient WHERE id = $1 AND alias = true AND emails.mail = alias_recipient.mail "#, user ) .fetch(&state.db); let name = sqlx::query!("SELECT name FROM accounts WHERE id = $1", user) .fetch_optional(&state.db) .await? .and_then(|r| r.name); let mut aliases: HashMap<_, _> = aliases.into_iter().map(|a| (a.mail, Vec::new())).collect(); while let Some(alias) = alias_stream.try_next().await? { aliases.get_mut(&alias.alias).unwrap().push(alias.recipient); } let aliases: Vec<_> = aliases .into_iter() .map(|(mail, recipients)| Alias { mail, recipients }) .sorted_by(|a, b| a.mail.cmp(&b.mail)) .collect(); let mut context = tera::Context::new(); context.insert("mails", &mails); context.insert("mail_domain", &state.mail_domain); context.insert("aliases", &aliases); context.insert("name", &name); if let Some(err) = query.user_error { tracing::info!("User error: {err:?}"); context.insert("user_error", &err.to_string()); } context.extend(global_context()); Ok(Html(TEMPLATES.render("home.html", &context)?)) } #[tracing::instrument(skip(state))] async fn delete_mail( state: State>, User(user): User, Form(delete): Form, ) -> Result { let rows_affected = sqlx::query!( "DELETE FROM emails WHERE id = $1 AND mail = $2", user, delete.mail ) .execute(&state.db) .await? .rows_affected(); if rows_affected != 1 { tracing::warn!("Invalid number of rows affected in delete: {rows_affected}"); return Err(Error::InternalError); } Ok(Redirect::to("/")) } #[tracing::instrument(skip(state))] async fn delete_alias( state: State>, User(user): User, Form(delete): Form, ) -> Result { let mut tx = state.db.begin().await?; sqlx::query!("DELETE FROM alias_recipient WHERE mail = $1", delete.mail) .execute(&mut *tx) .await?; let rows_affected = sqlx::query!( "DELETE FROM emails WHERE id = $1 AND mail = $2", user, delete.mail ) .execute(&mut *tx) .await? .rows_affected(); if rows_affected != 1 { tracing::warn!("Invalid number of rows affected in delete: {rows_affected}"); tx.rollback().await?; return Err(Error::InternalError); } tx.commit().await?; Ok(Redirect::to("/")) } #[derive(Serialize, Deserialize, Debug)] enum UserError { MailAlreadyExists, NameAlreadyExists(String), } impl Display for UserError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { UserError::MailAlreadyExists => { write!(f, "email address is already used by another user") } UserError::NameAlreadyExists(n) => { write!(f, "account name '{n}' is already used by another user") } } } } impl From for Redirect { fn from(value: UserError) -> Self { Redirect::to(&format!( "/?{}", serde_urlencoded::to_string(HomeQuery { user_error: Some(value) }) .expect("could not generate user error") )) } } async fn add_mail( state: State>, User(user): User, Form(add): Form, ) -> Result { let has_mail = sqlx::query!( "SELECT COUNT(*) FROM emails WHERE id != $1 AND mail = $2", user, add.mail ) .fetch_one(&state.db) .await? .count .expect("count should not be null"); if has_mail != 0 { Ok(UserError::MailAlreadyExists.into()) } else { sqlx::query!( "INSERT INTO emails (id, mail) VALUES ($1, $2) ON CONFLICT DO NOTHING", user, add.mail ) .execute(&state.db) .await?; Ok(Redirect::to("/")) } } async fn add_alias( state: State>, User(user): User, Form(add): Form, ) -> Result { let has_mail = sqlx::query!( "SELECT COUNT(*) FROM emails WHERE id != $1 AND mail = $2", user, add.mail ) .fetch_one(&state.db) .await? .count .expect("count should not be null"); if has_mail != 0 { Ok(UserError::MailAlreadyExists.into()) } else { sqlx::query!( "INSERT INTO emails (id, mail, alias) VALUES ($1, $2, true) ON CONFLICT DO NOTHING", user, add.mail ) .execute(&state.db) .await?; Ok(Redirect::to("/")) } } #[tracing::instrument(skip(state))] async fn add_recipient( state: State>, User(user): User, Form(add): Form, ) -> Result { let can_use_alias = sqlx::query!( "SELECT COUNT(*) FROM emails WHERE id = $1 AND mail = $2", user, add.alias ) .fetch_one(&state.db) .await? .count .expect("count should not be null") > 0; if !can_use_alias { tracing::error!("User is not authorized to use this alias"); return Err(Error::InternalError); } sqlx::query!( "INSERT INTO alias_recipient (mail, recipient) VALUES ($1, $2) ON CONFLICT DO NOTHING", add.alias, add.recipient ) .execute(&state.db) .await?; Ok(Redirect::to("/")) } #[tracing::instrument(skip(state))] async fn delete_recipient( state: State>, User(user): User, Form(delete): Form, ) -> Result { let can_use_alias = sqlx::query!( "SELECT COUNT(*) FROM emails WHERE id = $1 AND mail = $2", user, delete.alias ) .fetch_one(&state.db) .await? .count .expect("count should not be null") > 0; if !can_use_alias { tracing::error!("User is not authorized to use this alias"); return Err(Error::InternalError); } let rows_affected = sqlx::query!( "DELETE FROM alias_recipient WHERE mail = $1 AND recipient = $2", delete.alias, delete.recipient, ) .execute(&state.db) .await? .rows_affected(); if rows_affected != 1 { tracing::warn!("Invalid number of rows affected in delete: {rows_affected}"); return Err(Error::InternalError); } Ok(Redirect::to("/")) } #[derive(Deserialize, Debug)] struct Password { password: SecretString, } #[tracing::instrument(skip(state))] async fn set_password( state: State>, User(user): User, Form(password): Form, ) -> Result { let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); let password_hash = argon2 .hash_password(password.password.expose_secret().as_bytes(), &salt)? .to_string(); sqlx::query!( "UPDATE accounts SET password = $1 WHERE id = $2", password_hash, user ) .execute(&state.db) .await?; Ok(Redirect::to("/")) } #[derive(Deserialize, Debug)] struct Name { name: String, } #[tracing::instrument(skip(state))] async fn set_name( state: State>, User(user): User, Form(name): Form, ) -> Result { let taken = sqlx::query!( "SELECT COUNT(*) FROM accounts WHERE name = $1 AND id != $2", name.name, user ) .fetch_one(&state.db) .await? .count .expect("count returned null") != 0; if taken { Ok(UserError::NameAlreadyExists(name.name).into()) } else { sqlx::query!( "UPDATE accounts SET name = $1 WHERE id = $2", name.name, user ) .execute(&state.db) .await?; Ok(Redirect::to("/")) } } #[tokio::main] async fn main() -> color_eyre::Result<()> { color_eyre::install()?; tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) .with_target(true) .with_env_filter(EnvFilter::from_default_env()) .init(); let config = Settings::new()?; tracing::info!("Settings: {config:#?}"); let oidc = OpenidConnector::new(&config).await?; let addr: SocketAddr = format!("{}:{}", config.address, config.port).parse()?; let db = PgPoolOptions::new() .max_connections(5) .connect(&config.database_url) .await?; sqlx::migrate!().run(&db).await?; tracing::info!("Listening on {addr}"); let router = Router::new() .route("/login", get(login)) .route("/login/redirect/:id", get(redirected)) .route("/", get(home)) .route("/mail/delete", post(delete_mail)) .route("/mail/add", post(add_mail)) .route("/alias/add", post(add_alias)) .route("/alias/recipient/add", post(add_recipient)) .route("/alias/recipient/delete", post(delete_recipient)) .route("/alias/delete", post(delete_alias)) .route("/password", post(set_password)) .route("/name", post(set_name)) .fallback(page_not_found) .with_state(Arc::new(AppState { db, oidc, jwt_secret: config.jwt_secret.0, mail_domain: config.mail_domain, })); Ok(axum::Server::bind(&addr) .serve(router.into_make_service()) .await?) }