update: oauth state more understandable

This commit is contained in:
maix0 2024-09-28 12:19:02 +02:00
parent e151a799ef
commit 180503a244
3 changed files with 47 additions and 48 deletions

View file

@ -6,7 +6,7 @@
use std::{
collections::{HashMap, HashSet},
sync::{Arc, RwLock},
sync::Arc,
time::Duration,
};
@ -14,21 +14,21 @@ use axum::{
async_trait,
extract::{FromRef, FromRequestParts, Query, State},
http::{request::Parts, StatusCode},
response::{AppendHeaders, Html, IntoResponse, Redirect},
response::{Html, IntoResponse, Redirect},
routing::get,
Json, Router,
Router,
};
use axum_extra::extract::{
cookie::{Cookie, Expiration, Key, SameSite},
cookie::{Cookie, Key, SameSite},
CookieJar, PrivateCookieJar,
};
use base64::Engine;
use color_eyre::eyre::Context;
use reqwest::tls::Version;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{json, Value};
use tokio::{io::AsyncReadExt, sync::Mutex};
use tracing::{error, info};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::sync::Mutex;
use tracing::{error, info, warn};
macro_rules! unwrap_env {
($name:literal) => {
@ -78,6 +78,7 @@ async fn tutors(config: AppState) {
"page[number]": page_nb,
"page[size]": 100,
}),
Option::<&oauth2::Token>::None,
)
.await
.unwrap();
@ -119,7 +120,7 @@ async fn main() {
http.clone(),
unwrap_env!("CLIENT_ID"),
unwrap_env!("CLIENT_SECRET"),
"http://local.maix.me/auth/callback",
"http://local.maix.me:9911/auth/callback",
)
.await
.unwrap();
@ -147,7 +148,7 @@ async fn main() {
// run our app with hyper
let listener = tokio::net::TcpListener::bind(format!(
"127.0.0.1:{}",
"0.0.0.0:{}",
std::env::args()
.nth(1)
.and_then(|s| s.parse::<u16>().ok())
@ -175,9 +176,6 @@ async fn oauth2_login(State(state): State<AppState>) -> Result<Redirect, StatusC
))
}
use time::Duration as TDuration;
use time::OffsetDateTime;
#[axum::debug_handler]
async fn oauth2_callback(
State(state): State<AppState>,
@ -197,18 +195,15 @@ async fn oauth2_callback(
.await
.wrap_err("callback")?;
let rep = state
.http
.get("https://api.intra.42.fr/v2/users/me")
.bearer_auth(&token.access_token)
.send()
let res: User42 = state
.oauth
.do_request("https://api.intra.42.fr/v2/me", &(), Some(&token))
.await
.wrap_err("Unable to get user self")?;
let json: User42 = rep.json().await.wrap_err("unable to parse api reply")?;
let mut cookie = Cookie::new("token", json.id.to_string());
let mut cookie = Cookie::new("token", res.id.to_string());
cookie.set_same_site(SameSite::None);
cookie.set_secure(true);
cookie.set_secure(false);
cookie.set_path("/");
// cookie.set_domain("localhost:3000");
// cookie.set_http_only(Some(false));
@ -225,23 +220,20 @@ async fn oauth2_callback(
}
#[derive(Clone, Debug)]
struct UserLoggedIn {
id: u64,
}
struct UserLoggedIn;
#[async_trait]
impl FromRequestParts<AppState> for UserLoggedIn {
type Rejection = (StatusCode, CookieJar, Redirect);
type Rejection = (StatusCode, PrivateCookieJar, Redirect);
async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
info!("banane");
let jar = CookieJar::from_request_parts(parts, state).await.unwrap();
dbg!(&jar);
let jar = PrivateCookieJar::from_request_parts(parts, state)
.await
.unwrap();
let Some(id) = jar.get("token") else {
info!("no cookie");
return Err((
StatusCode::TEMPORARY_REDIRECT,
jar,
@ -251,7 +243,6 @@ impl FromRequestParts<AppState> for UserLoggedIn {
let Ok(user_id) = id.value().parse::<u64>() else {
let jar = jar.remove("token");
info!("not id");
return Err((
StatusCode::TEMPORARY_REDIRECT,
jar,
@ -260,10 +251,8 @@ impl FromRequestParts<AppState> for UserLoggedIn {
};
if state.tutors.lock().await.contains(&user_id) {
info!("is tut");
Ok(UserLoggedIn { id: user_id })
Ok(UserLoggedIn)
} else {
info!("not tut");
let jar = jar.remove("token");
Err((
StatusCode::TEMPORARY_REDIRECT,
@ -283,9 +272,8 @@ async fn root(_user: UserLoggedIn) -> Html<&'static str> {
<a href="/stop">stop</a><br>
<a href="/start">start</a><br>
<a href="/status">status</a><br>
<a href="/db">db</a><br>
<a href="/pull">git pull (ask before!)</a><br>
"#,
"#,
)
}