Allow to login with OIDC
This commit is contained in:
parent
f556fec3bb
commit
7a2ff7ad1d
12 changed files with 782 additions and 29 deletions
|
|
@ -9,8 +9,9 @@ pub struct Model {
|
|||
pub id: Uuid,
|
||||
#[sea_orm(unique)]
|
||||
pub name: String,
|
||||
#[sea_orm(column_type = "Binary(BlobSize::Blob(None))")]
|
||||
pub password: Vec<u8>,
|
||||
#[sea_orm(column_type = "Binary(BlobSize::Blob(None))", nullable)]
|
||||
pub password: Option<Vec<u8>>,
|
||||
pub open_id_subject: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
|
|
|
|||
258
src/main.rs
258
src/main.rs
|
|
@ -1,13 +1,26 @@
|
|||
use std::{net::SocketAddr, path::PathBuf, sync::Arc};
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
net::SocketAddr,
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use axum::Router;
|
||||
use base64::{engine::general_purpose, Engine};
|
||||
use jwt_simple::prelude::HS256Key;
|
||||
use migration::{Migrator, MigratorTrait};
|
||||
use openidconnect::{
|
||||
core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata},
|
||||
url::Url,
|
||||
AccessTokenHash, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce,
|
||||
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, TokenResponse,
|
||||
};
|
||||
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use tower_http::services::{ServeDir, ServeFile};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub(crate) mod entity;
|
||||
mod routes;
|
||||
|
|
@ -65,6 +78,44 @@ impl<'de> Deserialize<'de> for Base64 {
|
|||
}
|
||||
}
|
||||
|
||||
fn deserialize_comma<'de, D>(de: D) -> Result<Vec<openidconnect::Scope>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Visitor;
|
||||
|
||||
struct CommaVisitor;
|
||||
impl<'de> Visitor<'de> for CommaVisitor {
|
||||
type Value = Vec<openidconnect::Scope>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("string containg comma separated strings")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(v.split(',')
|
||||
.map(|v| openidconnect::Scope::new(v.to_string()))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
de.deserialize_str(CommaVisitor)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
struct OpenidConnectSettings {
|
||||
url: String,
|
||||
id: ClientId,
|
||||
secret: ClientSecret,
|
||||
domain: String,
|
||||
#[serde(deserialize_with = "deserialize_comma")]
|
||||
scopes: Vec<openidconnect::Scope>,
|
||||
}
|
||||
|
||||
fn default_host() -> String {
|
||||
"0.0.0.0".into()
|
||||
}
|
||||
|
|
@ -87,6 +138,8 @@ struct Settings {
|
|||
serve_app: Option<PathBuf>,
|
||||
database_url: String,
|
||||
#[serde(default)]
|
||||
oidc: Option<OpenidConnectSettings>,
|
||||
#[serde(default)]
|
||||
sqlx_logging: bool,
|
||||
}
|
||||
|
||||
|
|
@ -102,6 +155,198 @@ impl Settings {
|
|||
struct AppState {
|
||||
jwt_secret: Base64,
|
||||
db: DatabaseConnection,
|
||||
oidc: Option<OpenidConnector>,
|
||||
}
|
||||
|
||||
struct OpenidConnector {
|
||||
provider: CoreClient,
|
||||
scopes: Vec<openidconnect::Scope>,
|
||||
domain: Url,
|
||||
inflight: parking_lot::Mutex<FifoMap>,
|
||||
}
|
||||
|
||||
struct FifoMap {
|
||||
max: usize,
|
||||
order: VecDeque<Uuid>,
|
||||
values: HashMap<Uuid, OpenidAuthState>,
|
||||
}
|
||||
|
||||
struct FifoMapInsert<'a> {
|
||||
pub id: Uuid,
|
||||
map: &'a mut FifoMap,
|
||||
}
|
||||
|
||||
impl<'a> FifoMapInsert<'a> {
|
||||
pub fn insert(self, state: OpenidAuthState) {
|
||||
let FifoMapInsert { id, map } = self;
|
||||
|
||||
map.order.push_back(id);
|
||||
if map.order.len() > map.max {
|
||||
let first = map.order.pop_front().unwrap();
|
||||
map.values.remove(&first);
|
||||
}
|
||||
|
||||
map.values.insert(id, state);
|
||||
}
|
||||
}
|
||||
|
||||
impl FifoMap {
|
||||
pub fn new(max: usize) -> Self {
|
||||
Self {
|
||||
max,
|
||||
order: VecDeque::new(),
|
||||
values: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_entry(&mut self) -> FifoMapInsert {
|
||||
let id = loop {
|
||||
let id = Uuid::new_v4();
|
||||
if !self.values.contains_key(&id) {
|
||||
break id;
|
||||
}
|
||||
};
|
||||
|
||||
FifoMapInsert { id, map: self }
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, id: Uuid) -> Option<OpenidAuthState> {
|
||||
let state = self.values.remove(&id);
|
||||
|
||||
if let Some(idx) = self.order.iter().position(|&v| v == id) {
|
||||
self.order.remove(idx);
|
||||
};
|
||||
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
struct OpenidAuthState {
|
||||
pkce_verifier: PkceCodeVerifier,
|
||||
csrf_token: CsrfToken,
|
||||
nonce: Nonce,
|
||||
source_url: String,
|
||||
}
|
||||
|
||||
pub struct OpenidAccount {
|
||||
pub sub: String,
|
||||
pub name: String,
|
||||
pub source_url: String,
|
||||
}
|
||||
|
||||
impl OpenidConnector {
|
||||
async fn new(settings: OpenidConnectSettings) -> anyhow::Result<Self> {
|
||||
let metadata = CoreProviderMetadata::discover_async(
|
||||
IssuerUrl::new(settings.url)?,
|
||||
openidconnect::reqwest::async_http_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let provider =
|
||||
CoreClient::from_provider_metadata(metadata, settings.id, Some(settings.secret));
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
scopes: settings.scopes,
|
||||
domain: settings.domain.parse()?,
|
||||
inflight: parking_lot::Mutex::new(FifoMap::new(1024)),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn start_auth(&self, source_url: String) -> Url {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
|
||||
let mut inflight = self.inflight.lock();
|
||||
|
||||
let slot = inflight.new_entry();
|
||||
|
||||
tracing::info!("Login flow with id {}", slot.id);
|
||||
|
||||
let (url, csrf_token, nonce) = self
|
||||
.provider
|
||||
.authorize_url(
|
||||
CoreAuthenticationFlow::AuthorizationCode,
|
||||
CsrfToken::new_random,
|
||||
Nonce::new_random,
|
||||
)
|
||||
.add_scopes(self.scopes.iter().cloned())
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
||||
self.domain
|
||||
.join(&format!("api/login/redirect/{}", slot.id))
|
||||
.unwrap(),
|
||||
)))
|
||||
.url();
|
||||
|
||||
slot.insert(OpenidAuthState {
|
||||
pkce_verifier,
|
||||
csrf_token,
|
||||
nonce,
|
||||
source_url,
|
||||
});
|
||||
|
||||
url
|
||||
}
|
||||
|
||||
pub async fn redirected(
|
||||
&self,
|
||||
id: Uuid,
|
||||
csrf_state: String,
|
||||
code: String,
|
||||
) -> anyhow::Result<OpenidAccount> {
|
||||
let Some(state) = self.inflight.lock().remove(id) else {
|
||||
anyhow::bail!("Redirect has expired or has already been submitted")
|
||||
};
|
||||
|
||||
tracing::info!("Redirected to continue login flow for {id}");
|
||||
|
||||
if state.csrf_token.secret() != &csrf_state {
|
||||
anyhow::bail!(
|
||||
"CRSF state does not match, expected: {} got {csrf_state}",
|
||||
state.csrf_token.secret()
|
||||
)
|
||||
}
|
||||
|
||||
let token_response = self
|
||||
.provider
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.set_pkce_verifier(state.pkce_verifier)
|
||||
.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::from_url(
|
||||
self.domain
|
||||
.join(&format!("api/login/redirect/{}", id))
|
||||
.unwrap(),
|
||||
)))
|
||||
.request_async(openidconnect::reqwest::async_http_client)
|
||||
.await?;
|
||||
|
||||
let Some(id_token) = token_response.id_token() else {
|
||||
anyhow::bail!("Server did not return an ID token")
|
||||
};
|
||||
|
||||
let claims = id_token.claims(&self.provider.id_token_verifier(), &state.nonce)?;
|
||||
|
||||
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||
let actual_access_token_hash = AccessTokenHash::from_token(
|
||||
token_response.access_token(),
|
||||
&id_token.signing_alg()?,
|
||||
)?;
|
||||
|
||||
if actual_access_token_hash != *expected_access_token_hash {
|
||||
anyhow::bail!("Invalid access token")
|
||||
}
|
||||
}
|
||||
|
||||
Ok(OpenidAccount {
|
||||
sub: claims.subject().to_string(),
|
||||
name: claims
|
||||
.name()
|
||||
.and_then(|v| v.get(None))
|
||||
.map(|v| v.to_string())
|
||||
.or(claims.preferred_username().map(|v| v.to_string()))
|
||||
.ok_or_else(|| anyhow!("No name or preferred_username"))?,
|
||||
source_url: state.source_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
|
|
@ -121,9 +366,15 @@ async fn main() -> anyhow::Result<()> {
|
|||
let mut opt = ConnectOptions::new(config.database_url);
|
||||
opt.sqlx_logging(config.sqlx_logging);
|
||||
|
||||
let oidc = match config.oidc {
|
||||
None => None,
|
||||
Some(settings) => Some(OpenidConnector::new(settings).await?),
|
||||
};
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
jwt_secret: config.jwt_secret,
|
||||
db: Database::connect(opt).await?,
|
||||
oidc,
|
||||
});
|
||||
|
||||
Migrator::up(&state.db, None).await?;
|
||||
|
|
@ -131,7 +382,10 @@ async fn main() -> anyhow::Result<()> {
|
|||
let router = Router::new()
|
||||
.nest(
|
||||
"/api",
|
||||
routes::router(config.api_allowed.map(|s| s.parse()).transpose()?),
|
||||
routes::router(
|
||||
config.api_allowed.map(|s| s.parse()).transpose()?,
|
||||
state.oidc.is_some(),
|
||||
),
|
||||
)
|
||||
.with_state(state);
|
||||
|
||||
|
|
|
|||
|
|
@ -3,19 +3,19 @@ use std::sync::Arc;
|
|||
use api::{LoginRequest, LoginResponse, UserInfo};
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{FromRef, FromRequestParts, Path, State},
|
||||
extract::{FromRef, FromRequestParts, Path, Query, State},
|
||||
headers::{authorization::Bearer, Authorization},
|
||||
http::{
|
||||
header::{AUTHORIZATION, CONTENT_TYPE},
|
||||
request::Parts,
|
||||
HeaderValue, Method, StatusCode,
|
||||
},
|
||||
response::IntoResponse,
|
||||
response::{IntoResponse, Redirect},
|
||||
routing::{delete, get, patch, post, put},
|
||||
Json, Router, TypedHeader,
|
||||
};
|
||||
use jwt_simple::prelude::*;
|
||||
use sea_orm::{prelude::*, TransactionError};
|
||||
use sea_orm::{prelude::*, ActiveValue, TransactionError};
|
||||
use sha2::{Digest, Sha512};
|
||||
use tower_http::cors::{self, AllowOrigin, CorsLayer};
|
||||
|
||||
|
|
@ -45,6 +45,8 @@ enum RouteError {
|
|||
RessourceNotFound,
|
||||
#[error("The request was malformed")]
|
||||
InvalidRequest(String),
|
||||
#[error("A normal account with this name already exists")]
|
||||
NormalAccount,
|
||||
#[error("Error in DB transaction")]
|
||||
TxnError(#[from] TransactionError<Box<RouteError>>),
|
||||
}
|
||||
|
|
@ -77,6 +79,9 @@ impl IntoResponse for RouteError {
|
|||
.into_response(),
|
||||
RouteError::RessourceNotFound => StatusCode::NOT_FOUND.into_response(),
|
||||
RouteError::InvalidRequest(reason) => (StatusCode::BAD_REQUEST, reason).into_response(),
|
||||
s @ RouteError::NormalAccount => {
|
||||
(StatusCode::BAD_REQUEST, s.to_string()).into_response()
|
||||
}
|
||||
e => {
|
||||
tracing::error!("Internal error: {e:?}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
|
|
@ -138,12 +143,16 @@ async fn login(
|
|||
return Err(RouteError::UnknownAccount)
|
||||
};
|
||||
|
||||
let Some(password) = user.password.as_ref() else {
|
||||
return Err(RouteError::UnknownAccount);
|
||||
};
|
||||
|
||||
let mut hasher = Sha512::new();
|
||||
hasher.update(user.id.as_bytes());
|
||||
hasher.update(req.password.as_bytes());
|
||||
let hash = hasher.finalize();
|
||||
|
||||
if hash[..] != user.password {
|
||||
if &hash[..] != password {
|
||||
return Err(RouteError::UnknownAccount);
|
||||
}
|
||||
|
||||
|
|
@ -155,6 +164,92 @@ async fn login(
|
|||
Ok(Json(LoginResponse { token }))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OidcStartParam {
|
||||
r#return: String,
|
||||
}
|
||||
|
||||
async fn oidc_login(
|
||||
State(state): State<AppState>,
|
||||
Query(start): Query<OidcStartParam>,
|
||||
) -> Result<Redirect, RouteError> {
|
||||
tracing::info!("Starting OIDC login");
|
||||
let oidc = state.oidc.as_ref().unwrap();
|
||||
|
||||
let redirect_url = oidc.start_auth(start.r#return);
|
||||
|
||||
Ok(Redirect::to(redirect_url.as_str()))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OidcRedirectParams {
|
||||
state: String,
|
||||
code: String,
|
||||
}
|
||||
|
||||
async fn oidc_login_finish(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<Uuid>,
|
||||
Query(redirect): Query<OidcRedirectParams>,
|
||||
) -> Result<Redirect, RouteError> {
|
||||
match state
|
||||
.oidc
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.redirected(id, redirect.state, redirect.code)
|
||||
.await
|
||||
{
|
||||
Err(e) => {
|
||||
tracing::error!("Error when finishing OAuth2 flow {e:?}");
|
||||
Err(RouteError::Unauthorized)
|
||||
}
|
||||
Ok(account) => {
|
||||
let user = User::find()
|
||||
.filter(
|
||||
user::Column::Name
|
||||
.eq(&account.name)
|
||||
.or(user::Column::OpenIdSubject.eq(&account.sub)),
|
||||
)
|
||||
.one(&state.db)
|
||||
.await?;
|
||||
|
||||
let user = match user {
|
||||
None => {
|
||||
let model = user::ActiveModel {
|
||||
id: ActiveValue::Set(Uuid::new_v4()),
|
||||
name: ActiveValue::Set(account.name),
|
||||
password: ActiveValue::NotSet,
|
||||
open_id_subject: ActiveValue::Set(Some(account.sub)),
|
||||
};
|
||||
|
||||
model.insert(&state.db).await?
|
||||
}
|
||||
Some(user) => {
|
||||
if user.open_id_subject.as_ref() != Some(&account.sub) {
|
||||
return Err(RouteError::NormalAccount);
|
||||
}
|
||||
|
||||
user
|
||||
}
|
||||
};
|
||||
|
||||
let mut claims = Claims::create(Duration::from_secs(3600 * 24 * 31 * 6));
|
||||
claims.subject = Some(user.id.to_string());
|
||||
|
||||
let token = state.jwt_secret.0.authenticate(claims)?;
|
||||
|
||||
let redirect = format!(
|
||||
"{}?token={}&username={}",
|
||||
account.source_url,
|
||||
urlencoding::encode(&token),
|
||||
urlencoding::encode(&user.name),
|
||||
);
|
||||
|
||||
Ok(Redirect::to(&redirect))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_user_id(
|
||||
_: AuthenticatedUser,
|
||||
State(state): State<AppState>,
|
||||
|
|
@ -173,7 +268,7 @@ async fn get_user_id(
|
|||
})))
|
||||
}
|
||||
|
||||
pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
|
||||
pub(crate) fn router(api_allowed: Option<HeaderValue>, has_oidc: bool) -> Router<AppState> {
|
||||
let origin: AllowOrigin = match api_allowed {
|
||||
Some(n) => n.into(),
|
||||
None => cors::Any.into(),
|
||||
|
|
@ -185,7 +280,7 @@ pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
|
|||
|
||||
let mk_service = |m: Vec<Method>| cors_base.clone().allow_methods(m);
|
||||
|
||||
Router::new()
|
||||
let router = Router::new()
|
||||
.route(
|
||||
"/search/user/:name",
|
||||
get(get_user_id).layer(mk_service(vec![Method::GET])),
|
||||
|
|
@ -246,5 +341,18 @@ pub(crate) fn router(api_allowed: Option<HeaderValue>) -> Router<AppState> {
|
|||
.delete(recipe::delete_ig)
|
||||
.put(recipe::add_ig_request)
|
||||
.layer(mk_service(vec![Method::PATCH, Method::DELETE, Method::PUT])),
|
||||
)
|
||||
);
|
||||
|
||||
if has_oidc {
|
||||
async fn unit() {}
|
||||
router
|
||||
.route("/login/oidc", get(oidc_login))
|
||||
.route(
|
||||
"/login/has_oidc",
|
||||
get(unit).layer(mk_service(vec![Method::GET])),
|
||||
)
|
||||
.route("/login/redirect/:id", get(oidc_login_finish))
|
||||
} else {
|
||||
router
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue