Allow to login with OIDC

This commit is contained in:
traxys 2023-07-27 00:06:36 +02:00
parent f556fec3bb
commit 7a2ff7ad1d
12 changed files with 782 additions and 29 deletions

View file

@ -9,8 +9,9 @@ pub struct Model {
pub id: Uuid,
#[sea_orm(unique)]
pub name: String,
#[sea_orm(column_type = "Binary(BlobSize::Blob(None))")]
pub password: Vec<u8>,
#[sea_orm(column_type = "Binary(BlobSize::Blob(None))", nullable)]
pub password: Option<Vec<u8>>,
pub open_id_subject: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View file

@ -1,13 +1,26 @@
use std::{net::SocketAddr, path::PathBuf, sync::Arc};
use std::{
collections::{HashMap, VecDeque},
net::SocketAddr,
path::PathBuf,
sync::Arc,
};
use anyhow::anyhow;
use axum::Router;
use base64::{engine::general_purpose, Engine};
use jwt_simple::prelude::HS256Key;
use migration::{Migrator, MigratorTrait};
use openidconnect::{
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
url::Url,
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse,
};
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tower_http::services::{ServeDir, ServeFile};
use tracing_subscriber::EnvFilter;
use uuid::Uuid;
pub(crate) mod entity;
mod routes;
@ -65,6 +78,44 @@ impl<'de> Deserialize<'de> for Base64 {
}
}
fn deserialize_comma<'de, D>(de: D) -> Result<Vec<openidconnect::Scope>, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Visitor;
struct CommaVisitor;
impl<'de> Visitor<'de> for CommaVisitor {
type Value = Vec<openidconnect::Scope>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("string containg comma separated strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.split(',')
.map(|v| openidconnect::Scope::new(v.to_string()))
.collect())
}
}
de.deserialize_str(CommaVisitor)
}
#[derive(Deserialize, Debug)]
#[serde(rename_all = "UPPERCASE")]
struct OpenidConnectSettings {
url: String,
id: ClientId,
secret: ClientSecret,
domain: String,
#[serde(deserialize_with = "deserialize_comma")]
scopes: Vec<openidconnect::Scope>,
}
fn default_host() -> String {
"0.0.0.0".into()
}
@ -87,6 +138,8 @@ struct Settings {
serve_app: Option<PathBuf>,
database_url: String,
#[serde(default)]
oidc: Option<OpenidConnectSettings>,
#[serde(default)]
sqlx_logging: bool,
}
@ -102,6 +155,198 @@ impl Settings {
struct AppState {
jwt_secret: Base64,
db: DatabaseConnection,
oidc: Option<OpenidConnector>,
}
struct OpenidConnector {
provider: CoreClient,
scopes: Vec<openidconnect::Scope>,
domain: Url,
inflight: parking_lot::Mutex<FifoMap>,
}
struct FifoMap {
max: usize,
order: VecDeque<Uuid>,
values: HashMap<Uuid, OpenidAuthState>,
}
struct FifoMapInsert<'a> {
pub id: Uuid,
map: &'a mut FifoMap,
}
impl<'a> FifoMapInsert<'a> {
pub fn insert(self, state: OpenidAuthState) {
let FifoMapInsert { id, map } = self;
map.order.push_back(id);
if map.order.len() > map.max {
let first = map.order.pop_front().unwrap();
map.values.remove(&first);
}
map.values.insert(id, state);
}
}
impl FifoMap {
pub fn new(max: usize) -> Self {
Self {
max,
order: VecDeque::new(),
values: HashMap::new(),
}
}
pub fn new_entry(&mut self) -> FifoMapInsert {
let id = loop {
let id = Uuid::new_v4();
if !self.values.contains_key(&id) {
break id;
}
};
FifoMapInsert { id, map: self }
}
pub fn remove(&mut self, id: Uuid) -> Option<OpenidAuthState> {
let state = self.values.remove(&id);
if let Some(idx) = self.order.iter().position(|&v| v == id) {
self.order.remove(idx);
};
state
}
}
struct OpenidAuthState {
pkce_verifier: PkceCodeVerifier,
csrf_token: CsrfToken,
nonce: Nonce,
source_url: String,
}
pub struct OpenidAccount {
pub sub: String,
pub name: String,
pub source_url: String,
}
impl OpenidConnector {
async fn new(settings: OpenidConnectSettings) -> anyhow::Result<Self> {
let metadata = CoreProviderMetadata::discover_async(
IssuerUrl::new(settings.url)?,
openidconnect::reqwest::async_http_client,
)
.await?;
let provider =
CoreClient::from_provider_metadata(metadata, settings.id, Some(settings.secret));
Ok(Self {
provider,
scopes: settings.scopes,
domain: settings.domain.parse()?,
inflight: parking_lot::Mutex::new(FifoMap::new(1024)),
})
}
pub fn start_auth(&self, source_url: String) -> Url {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut inflight = self.inflight.lock();
let slot = inflight.new_entry();
tracing::info!("Login flow with id {}", slot.id);
let (url, csrf_token, nonce) = self
.provider
.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
)
.add_scopes(self.scopes.iter().cloned())
.set_pkce_challenge(pkce_challenge)
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
self.domain
.join(&format!("api/login/redirect/{}", slot.id))
.unwrap(),
)))
.url();
slot.insert(OpenidAuthState {
pkce_verifier,
csrf_token,
nonce,
source_url,
});
url
}
pub async fn redirected(
&self,
id: Uuid,
csrf_state: String,
code: String,
) -> anyhow::Result<OpenidAccount> {
let Some(state) = self.inflight.lock().remove(id) else {
anyhow::bail!("Redirect has expired or has already been submitted")
};
tracing::info!("Redirected to continue login flow for {id}");
if state.csrf_token.secret() != &csrf_state {
anyhow::bail!(
"CRSF state does not match, expected: {} got {csrf_state}",
state.csrf_token.secret()
)
}
let token_response = self
.provider
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(state.pkce_verifier)
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
self.domain
.join(&format!("api/login/redirect/{}", id))
.unwrap(),
)))
.request_async(openidconnect::reqwest::async_http_client)
.await?;
let Some(id_token) = token_response.id_token() else {
anyhow::bail!("Server did not return an ID token")
};
let claims = id_token.claims(&self.provider.id_token_verifier(), &state.nonce)?;
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(
token_response.access_token(),
&id_token.signing_alg()?,
)?;
if actual_access_token_hash != *expected_access_token_hash {
anyhow::bail!("Invalid access token")
}
}
Ok(OpenidAccount {
sub: claims.subject().to_string(),
name: claims
.name()
.and_then(|v| v.get(None))
.map(|v| v.to_string())
.or(claims.preferred_username().map(|v| v.to_string()))
.ok_or_else(|| anyhow!("No name or preferred_username"))?,
source_url: state.source_url,
})
}
}
#[tokio::main]
@ -121,9 +366,15 @@ async fn main() -> anyhow::Result<()> {
let mut opt = ConnectOptions::new(config.database_url);
opt.sqlx_logging(config.sqlx_logging);
let oidc = match config.oidc {
None => None,
Some(settings) => Some(OpenidConnector::new(settings).await?),
};
let state = Arc::new(AppState {
jwt_secret: config.jwt_secret,
db: Database::connect(opt).await?,
oidc,
});
Migrator::up(&state.db, None).await?;
@ -131,7 +382,10 @@ async fn main() -> anyhow::Result<()> {
let router = Router::new()
.nest(
"/api",
routes::router(config.api_allowed.map(|s| s.parse()).transpose()?),
routes::router(
config.api_allowed.map(|s| s.parse()).transpose()?,
state.oidc.is_some(),
),
)
.with_state(state);

View file

@ -3,19 +3,19 @@ use std::sync::Arc;
use api::{LoginRequest, LoginResponse, UserInfo};
use axum::{
async_trait,
extract::{FromRef, FromRequestParts, Path, State},
extract::{FromRef, FromRequestParts, Path, Query, State},
headers::{authorization::Bearer, Authorization},
http::{
header::{AUTHORIZATION, CONTENT_TYPE},
request::Parts,
HeaderValue, Method, StatusCode,
},
response::IntoResponse,
response::{IntoResponse, Redirect},
routing::{delete, get, patch, post, put},
Json, Router, TypedHeader,
};
use jwt_simple::prelude::*;
use sea_orm::{prelude::*, TransactionError};
use sea_orm::{prelude::*, ActiveValue, TransactionError};
use sha2::{Digest, Sha512};
use tower_http::cors::{self, AllowOrigin, CorsLayer};
@ -45,6 +45,8 @@ enum RouteError {
RessourceNotFound,
#[error("The request was malformed")]
InvalidRequest(String),
#[error("A normal account with this name already exists")]
NormalAccount,
#[error("Error in DB transaction")]
TxnError(#[from] TransactionError<Box<RouteError>>),
}
@ -77,6 +79,9 @@ impl IntoResponse for RouteError {
.into_response(),
RouteError::RessourceNotFound => StatusCode::NOT_FOUND.into_response(),
RouteError::InvalidRequest(reason) => (StatusCode::BAD_REQUEST, reason).into_response(),
s @ RouteError::NormalAccount => {
(StatusCode::BAD_REQUEST, s.to_string()).into_response()
}
e => {
tracing::error!("Internal error: {e:?}");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
@ -138,12 +143,16 @@ async fn login(
return Err(RouteError::UnknownAccount)
};
let Some(password) = user.password.as_ref() else {
return Err(RouteError::UnknownAccount);
};
let mut hasher = Sha512::new();
hasher.update(user.id.as_bytes());
hasher.update(req.password.as_bytes());
let hash = hasher.finalize();
if hash[..] != user.password {
if &hash[..] != password {
return Err(RouteError::UnknownAccount);
}
@ -155,6 +164,92 @@ async fn login(
Ok(Json(LoginResponse { token }))
}
#[derive(Deserialize)]
struct OidcStartParam {
r#return: String,
}
async fn oidc_login(
State(state): State<AppState>,
Query(start): Query<OidcStartParam>,
) -> Result<Redirect, RouteError> {
tracing::info!("Starting OIDC login");
let oidc = state.oidc.as_ref().unwrap();
let redirect_url = oidc.start_auth(start.r#return);
Ok(Redirect::to(redirect_url.as_str()))
}
#[derive(Deserialize)]
struct OidcRedirectParams {
state: String,
code: String,
}
async fn oidc_login_finish(
State(state): State<AppState>,
Path(id): Path<Uuid>,
Query(redirect): Query<OidcRedirectParams>,
) -> Result<Redirect, RouteError> {
match state
.oidc
.as_ref()
.unwrap()
.redirected(id, redirect.state, redirect.code)
.await
{
Err(e) => {
tracing::error!("Error when finishing OAuth2 flow {e:?}");
Err(RouteError::Unauthorized)
}
Ok(account) => {
let user = User::find()
.filter(
user::Column::Name
.eq(&account.name)
.or(user::Column::OpenIdSubject.eq(&account.sub)),
)
.one(&state.db)
.await?;
let user = match user {
None => {
let model = user::ActiveModel {
id: ActiveValue::Set(Uuid::new_v4()),
name: ActiveValue::Set(account.name),
password: ActiveValue::NotSet,
open_id_subject: ActiveValue::Set(Some(account.sub)),
};
model.insert(&state.db).await?
}
Some(user) => {
if user.open_id_subject.as_ref() != Some(&account.sub) {
return Err(RouteError::NormalAccount);
}
user
}
};
let mut claims = Claims::create(Duration::from_secs(3600 * 24 * 31 * 6));
claims.subject = Some(user.id.to_string());
let token = state.jwt_secret.0.authenticate(claims)?;
let redirect = format!(
"{}?token={}&username={}",
account.source_url,
urlencoding::encode(&token),
urlencoding::encode(&user.name),
);
Ok(Redirect::to(&redirect))
}
}
}
async fn get_user_id(
_: AuthenticatedUser,
State(state): State<AppState>,
@ -173,7 +268,7 @@ async fn get_user_id(
})))
}
pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
pub(crate) fn router(api_allowed: Option<HeaderValue>, has_oidc: bool) -> Router<AppState> {
let origin: AllowOrigin = match api_allowed {
Some(n) => n.into(),
None => cors::Any.into(),
@ -185,7 +280,7 @@ pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
let mk_service = |m: Vec<Method>| cors_base.clone().allow_methods(m);
Router::new()
let router = Router::new()
.route(
"/search/user/:name",
get(get_user_id).layer(mk_service(vec![Method::GET])),
@ -246,5 +341,18 @@ pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
.delete(recipe::delete_ig)
.put(recipe::add_ig_request)
.layer(mk_service(vec![Method::PATCH, Method::DELETE, Method::PUT])),
)
);
if has_oidc {
async fn unit() {}
router
.route("/login/oidc", get(oidc_login))
.route(
"/login/has_oidc",
get(unit).layer(mk_service(vec![Method::GET])),
)
.route("/login/redirect/:id", get(oidc_login_finish))
} else {
router
}
}