Save the login information in a cookie
This commit is contained in:
parent
60d9c9b0a2
commit
2dfaf85eea
3 changed files with 336 additions and 24 deletions
92
src/main.rs
92
src/main.rs
|
|
@ -11,7 +11,11 @@ use axum::{
|
|||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use axum_extra::extract::{cookie::Cookie, CookieJar};
|
||||
use base64::{engine::general_purpose, engine::Engine};
|
||||
use color_eyre::eyre;
|
||||
use cookie::{time::OffsetDateTime, SameSite};
|
||||
use jwt_simple::prelude::*;
|
||||
use once_cell::sync::Lazy;
|
||||
use openidconnect::{
|
||||
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
|
||||
|
|
@ -33,6 +37,50 @@ fn default_address() -> String {
|
|||
"127.0.0.1".into()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct Base64(pub(crate) HS256Key);
|
||||
|
||||
impl std::fmt::Debug for Base64 {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
r#"b64"{}""#,
|
||||
&general_purpose::STANDARD.encode(self.0.to_bytes())
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Base64 {
|
||||
fn deserialize<D>(de: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Visitor;
|
||||
|
||||
struct DecodingVisitor;
|
||||
impl<'de> Visitor<'de> for DecodingVisitor {
|
||||
type Value = Base64;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("must be a base 64 string")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
general_purpose::STANDARD
|
||||
.decode(v)
|
||||
.map_err(E::custom)
|
||||
.map(|b| HS256Key::from_bytes(&b))
|
||||
.map(Base64)
|
||||
}
|
||||
}
|
||||
|
||||
de.deserialize_str(DecodingVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_comma<'de, D>(de: D) -> Result<Vec<openidconnect::Scope>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
|
|
@ -63,6 +111,7 @@ where
|
|||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
struct Settings {
|
||||
jwt_secret: Base64,
|
||||
#[serde(default = "default_port")]
|
||||
port: u16,
|
||||
#[serde(default = "default_address")]
|
||||
|
|
@ -262,8 +311,10 @@ impl OpenidConnector {
|
|||
}
|
||||
|
||||
struct AppState {
|
||||
jwt_secret: HS256Key,
|
||||
db: PgPool,
|
||||
oidc: OpenidConnector,
|
||||
domain: String,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
|
|
@ -272,6 +323,8 @@ enum Error {
|
|||
Db(#[from] sqlx::Error),
|
||||
#[error("An error occured when rendering a template")]
|
||||
Tera(#[from] tera::Error),
|
||||
#[error("A JWT error occured")]
|
||||
Jwt(#[from] jwt_simple::Error),
|
||||
#[error("An internal error occured")]
|
||||
InternalError,
|
||||
}
|
||||
|
|
@ -306,6 +359,10 @@ impl IntoResponse for Error {
|
|||
tracing::error!("Tera error: {e:?}");
|
||||
InternalError.into_response()
|
||||
}
|
||||
Error::Jwt(e) => {
|
||||
tracing::error!("JWT error: {e:?}");
|
||||
InternalError.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -326,7 +383,8 @@ async fn redirected(
|
|||
state: State<Arc<AppState>>,
|
||||
Path(id): Path<Uuid>,
|
||||
Query(redirect): Query<OidcRedirectParams>,
|
||||
) -> Result<(), Error> {
|
||||
jar: CookieJar,
|
||||
) -> Result<CookieJar, Error> {
|
||||
match state
|
||||
.oidc
|
||||
.redirected(id, redirect.state, redirect.code)
|
||||
|
|
@ -336,7 +394,7 @@ async fn redirected(
|
|||
let account = sqlx::query!("SELECT id FROM accounts WHERE sub = $1", sub)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
let _id = match account {
|
||||
let id = match account {
|
||||
Some(r) => r.id,
|
||||
None => {
|
||||
let id = Uuid::new_v4();
|
||||
|
|
@ -346,7 +404,28 @@ async fn redirected(
|
|||
id
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
|
||||
let expire = std::time::Duration::from_secs(3600 * 24 * 31 * 6);
|
||||
|
||||
let mut claims = Claims::create(expire.into());
|
||||
claims.subject = Some(id.to_string());
|
||||
|
||||
let token = state.jwt_secret.authenticate(claims)?;
|
||||
|
||||
let mut cookie = Cookie::named("mail_admin_token");
|
||||
cookie.set_value(token);
|
||||
cookie.set_http_only(true);
|
||||
|
||||
let mut now = OffsetDateTime::now_utc();
|
||||
now += expire;
|
||||
cookie.set_expires(now);
|
||||
cookie.set_same_site(SameSite::Strict);
|
||||
cookie.set_secure(true);
|
||||
cookie.set_path("/");
|
||||
|
||||
let jar = jar.add(cookie);
|
||||
|
||||
Ok(jar)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Could not finish OAuth2 flow: {e:?}");
|
||||
|
|
@ -403,7 +482,12 @@ async fn main() -> color_eyre::Result<()> {
|
|||
.route("/login", get(login))
|
||||
.route("/login/redirect/:id", get(redirected))
|
||||
.fallback(page_not_found)
|
||||
.with_state(Arc::new(AppState { db, oidc }));
|
||||
.with_state(Arc::new(AppState {
|
||||
db,
|
||||
oidc,
|
||||
jwt_secret: config.jwt_secret.0,
|
||||
domain: config.domain,
|
||||
}));
|
||||
|
||||
Ok(axum::Server::bind(&addr)
|
||||
.serve(router.into_make_service())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue