diff --git a/.djlintrc b/.djlintrc new file mode 100644 index 0000000..a5811ae --- /dev/null +++ b/.djlintrc @@ -0,0 +1,4 @@ +{ + "profile": "jinja", + "ignore": "D018" +} diff --git a/Cargo.lock b/Cargo.lock index 662bc02..42bc0cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1277,6 +1277,7 @@ dependencies = [ "openidconnect", "parking_lot", "serde", + "serde_urlencoded", "sqlx", "tera", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 7000315..fe6bcc9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ once_cell = "1.18.0" openidconnect = "3.3.0" parking_lot = "0.12.1" serde = { version = "1.0.183", features = ["derive"] } +serde_urlencoded = "0.7.1" sqlx = { version = "0.7.1", features = ["runtime-tokio", "postgres", "uuid", "migrate"] } tera = "1.19.0" thiserror = "1.0.44" diff --git a/src/main.rs b/src/main.rs index c99d97f..31d0051 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use std::{ collections::{HashMap, VecDeque}, + fmt::Display, net::SocketAddr, sync::Arc, }; @@ -119,6 +120,8 @@ struct Settings { address: String, database_url: String, + mail_domain: String, + domain: String, oidc_endpoint: String, @@ -314,6 +317,7 @@ impl OpenidConnector { struct AppState { jwt_secret: HS256Key, db: PgPool, + mail_domain: String, oidc: OpenidConnector, } @@ -486,33 +490,42 @@ where } } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] struct Mail { mail: String, } -async fn home(state: State>, User(user): User) -> Result, Error> { +#[derive(Serialize, Deserialize)] +struct HomeQuery { + user_error: Option, +} + +async fn home( + state: State>, + User(user): User, + Query(query): Query, +) -> Result, Error> { let mails = sqlx::query_as!(Mail, "SELECT mail FROM emails WHERE id = $1", user) .fetch_all(&state.db) .await?; let mut context = tera::Context::new(); context.insert("mails", &mails); + context.insert("mail_domain", &state.mail_domain); + if let Some(err) = query.user_error { + tracing::info!("User error: {err:?}"); + context.insert("user_error", &err.to_string()); + } context.extend(global_context()); Ok(Html(TEMPLATES.render("home.html", &context)?)) } -#[derive(Deserialize, Debug)] -struct MailDelete { - mail: String, -} - #[tracing::instrument(skip(state))] async fn delete_mail( state: State>, User(user): User, - Form(delete): Form, + Form(delete): Form, ) -> Result { let rows_affected = sqlx::query!( "DELETE FROM emails WHERE id = $1 AND mail = $2", @@ -531,6 +544,57 @@ async fn delete_mail( Ok(Redirect::to("/")) } +#[derive(Serialize, Deserialize, Debug)] +enum UserError { + MailAlreadyExists, +} + +impl Display for UserError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UserError::MailAlreadyExists => { + write!(f, "email address is already used by another user") + } + } + } +} + +async fn add_mail( + state: State>, + User(user): User, + Form(add): Form, +) -> Result { + let has_mail = sqlx::query!( + "SELECT COUNT(*) FROM emails WHERE id != $1 AND mail = $2", + user, + add.mail + ) + .fetch_one(&state.db) + .await? + .count + .expect("count should not be null"); + + if has_mail != 0 { + Ok(Redirect::to(&format!( + "/?{}", + serde_urlencoded::to_string(&HomeQuery { + user_error: Some(UserError::MailAlreadyExists) + }) + .expect("could not generate query") + ))) + } else { + sqlx::query!( + "INSERT INTO emails (id, mail) VALUES ($1, $2) ON CONFLICT DO NOTHING", + user, + add.mail + ) + .execute(&state.db) + .await?; + + Ok(Redirect::to("/")) + } +} + #[tokio::main] async fn main() -> color_eyre::Result<()> { color_eyre::install()?; @@ -562,11 +626,13 @@ async fn main() -> color_eyre::Result<()> { .route("/login/redirect/:id", get(redirected)) .route("/", get(home)) .route("/mail/delete", post(delete_mail)) + .route("/mail/add", post(add_mail)) .fallback(page_not_found) .with_state(Arc::new(AppState { db, oidc, jwt_secret: config.jwt_secret.0, + mail_domain: config.mail_domain, })); Ok(axum::Server::bind(&addr) diff --git a/templates/home.html b/templates/home.html index 6ae1723..5e85a71 100644 --- a/templates/home.html +++ b/templates/home.html @@ -8,6 +8,7 @@ {% block content %}

Mail management

+ {% if user_error %}
{{ user_error }}
{% endif %}

Mails

    {% for mail in mails %} @@ -43,5 +44,50 @@ {% endfor %}
+ +
+ {% endblock content %}