Save the login information in a cookie

This commit is contained in:
traxys 2023-08-16 21:16:30 +02:00
parent 60d9c9b0a2
commit 2dfaf85eea
3 changed files with 336 additions and 24 deletions

View file

@ -11,7 +11,11 @@ use axum::{
routing::get,
Router,
};
use axum_extra::extract::{cookie::Cookie, CookieJar};
use base64::{engine::general_purpose, engine::Engine};
use color_eyre::eyre;
use cookie::{time::OffsetDateTime, SameSite};
use jwt_simple::prelude::*;
use once_cell::sync::Lazy;
use openidconnect::{
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
@ -33,6 +37,50 @@ fn default_address() -> String {
"127.0.0.1".into()
}
#[derive(Clone)]
pub(crate) struct Base64(pub(crate) HS256Key);
impl std::fmt::Debug for Base64 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
r#"b64"{}""#,
&general_purpose::STANDARD.encode(self.0.to_bytes())
)
}
}
impl<'de> Deserialize<'de> for Base64 {
fn deserialize<D>(de: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Visitor;
struct DecodingVisitor;
impl<'de> Visitor<'de> for DecodingVisitor {
type Value = Base64;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("must be a base 64 string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
general_purpose::STANDARD
.decode(v)
.map_err(E::custom)
.map(|b| HS256Key::from_bytes(&b))
.map(Base64)
}
}
de.deserialize_str(DecodingVisitor)
}
}
fn deserialize_comma<'de, D>(de: D) -> Result<Vec<openidconnect::Scope>, D::Error>
where
D: Deserializer<'de>,
@ -63,6 +111,7 @@ where
#[derive(Deserialize, Debug)]
#[serde(rename_all = "UPPERCASE")]
struct Settings {
jwt_secret: Base64,
#[serde(default = "default_port")]
port: u16,
#[serde(default = "default_address")]
@ -262,8 +311,10 @@ impl OpenidConnector {
}
struct AppState {
jwt_secret: HS256Key,
db: PgPool,
oidc: OpenidConnector,
domain: String,
}
#[derive(thiserror::Error, Debug)]
@ -272,6 +323,8 @@ enum Error {
Db(#[from] sqlx::Error),
#[error("An error occured when rendering a template")]
Tera(#[from] tera::Error),
#[error("A JWT error occured")]
Jwt(#[from] jwt_simple::Error),
#[error("An internal error occured")]
InternalError,
}
@ -306,6 +359,10 @@ impl IntoResponse for Error {
tracing::error!("Tera error: {e:?}");
InternalError.into_response()
}
Error::Jwt(e) => {
tracing::error!("JWT error: {e:?}");
InternalError.into_response()
}
}
}
}
@ -326,7 +383,8 @@ async fn redirected(
state: State<Arc<AppState>>,
Path(id): Path<Uuid>,
Query(redirect): Query<OidcRedirectParams>,
) -> Result<(), Error> {
jar: CookieJar,
) -> Result<CookieJar, Error> {
match state
.oidc
.redirected(id, redirect.state, redirect.code)
@ -336,7 +394,7 @@ async fn redirected(
let account = sqlx::query!("SELECT id FROM accounts WHERE sub = $1", sub)
.fetch_optional(&state.db)
.await?;
let _id = match account {
let id = match account {
Some(r) => r.id,
None => {
let id = Uuid::new_v4();
@ -346,7 +404,28 @@ async fn redirected(
id
}
};
Ok(())
let expire = std::time::Duration::from_secs(3600 * 24 * 31 * 6);
let mut claims = Claims::create(expire.into());
claims.subject = Some(id.to_string());
let token = state.jwt_secret.authenticate(claims)?;
let mut cookie = Cookie::named("mail_admin_token");
cookie.set_value(token);
cookie.set_http_only(true);
let mut now = OffsetDateTime::now_utc();
now += expire;
cookie.set_expires(now);
cookie.set_same_site(SameSite::Strict);
cookie.set_secure(true);
cookie.set_path("/");
let jar = jar.add(cookie);
Ok(jar)
}
Err(e) => {
tracing::error!("Could not finish OAuth2 flow: {e:?}");
@ -403,7 +482,12 @@ async fn main() -> color_eyre::Result<()> {
.route("/login", get(login))
.route("/login/redirect/:id", get(redirected))
.fallback(page_not_found)
.with_state(Arc::new(AppState { db, oidc }));
.with_state(Arc::new(AppState {
db,
oidc,
jwt_secret: config.jwt_secret.0,
domain: config.domain,
}));
Ok(axum::Server::bind(&addr)
.serve(router.into_make_service())