// auth/src/main.rs
// (file-viewer metadata: 210 lines, 7.0 KiB, Rust)

use axum::extract::Query;
use axum::http::{HeaderMap, HeaderValue, Uri};
use axum::response::{Html, Redirect};
use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Form, Router};
use minijinja::{context, Environment};
use serde::Deserialize;
use sqlx::sqlite::SqlitePool;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::{borrow::BorrowMut, env};
use std::{collections::HashMap, str::FromStr};
use tokio::sync::Mutex;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies};
use tower_http::trace::{self, TraceLayer};
use tracing::Level;
use uuid::Uuid;
/// Application state shared by every handler (through axum's `State`) and by the
/// background session collector.
pub struct ServerState {
    /// SQLite pool used to look up users and their TOTP keys in the `USERS` table.
    pub db: sqlx::Pool<sqlx::Sqlite>,
    /// Active sessions, keyed by the session UUID stored in the cookie. The value is
    /// the instant at which the session expires; a session counts as live only
    /// while the current time is before it.
    pub session: Mutex<HashMap<Uuid, Instant>>,
}
/// Username plus TOTP material.
///
/// Used both as the submitted login form (where `otp` is the code the user typed)
/// and as the row type of the `USERS` lookup (where `otp` is the stored TOTP key
/// from the `KEY` column).
#[derive(Deserialize, sqlx::FromRow)]
pub struct UserLoginForm {
    #[sqlx(rename = "NAME")]
    pub username: String,
    #[sqlx(rename = "KEY")]
    pub otp: String,
}

/// `Debug` is written by hand so that `otp` — a one-time code or, for database rows,
/// the TOTP secret itself — never ends up in logs via `{:?}`.
impl std::fmt::Debug for UserLoginForm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserLoginForm")
            .field("username", &self.username)
            .field("otp", &"<redacted>")
            .finish()
    }
}
use once_cell::sync::Lazy;
/// Name of the session cookie; taken from the `COOKIE_NAME` env var, defaulting to
/// `aaron_auth`.
static COOKIE_NAME: Lazy<String> =
    Lazy::new(|| env::var("COOKIE_NAME").unwrap_or_else(|_| "aaron_auth".to_string()));

/// Session lifetime in seconds: a new session, and each successful `/auth` check,
/// sets the expiry to now + this value. It is also the period of the background
/// session collector. Read from `SESSION_ACTIVE_TIME`, defaulting to 600 when unset
/// or not a valid `u64`.
///
/// This must be a `static`: a `const` `Lazy` is instantiated afresh at every use
/// site, so the environment variable would be re-read and re-parsed on each access.
static SESSION_ACTIVE_TIME: Lazy<u64> = Lazy::new(|| {
    env::var("SESSION_ACTIVE_TIME")
        .ok()
        .and_then(|value| value.parse().ok())
        .unwrap_or(600)
});

/// Login page template, embedded into the binary at compile time.
const LOGIN_PAGE_HTML: &str = include_str!("../loginpage.html");
#[tokio::main]
async fn main() {
// 初始化日志记录器
tracing_subscriber::fmt::init();
let pool = SqlitePool::connect(&env::var("DATABASE_URL").expect("DB URL NOT SPECIFIED"))
.await
.expect("DB OPEN FAILURE"); // our router
let state = Arc::new(ServerState {
db: pool,
session: HashMap::new().into(),
// started: std::time::Instant::now(),
});
let app = Router::new()
.route("/auth", get(auth)) // http://127.0.0.1:3000
.route("/login", get(login_page).post(login))
.with_state(state.clone())
.layer(CookieManagerLayer::new())
.layer(
TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
);
let port = env::var("PORT").unwrap_or("3000".to_string());
let port = port.parse::<u16>().unwrap_or(3000);
let addr = SocketAddr::from(([127, 0, 0, 1], port));
// run it with hyper on localhost:3000
tokio::spawn(gc_task(state.clone()));
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
async fn gc_task(state: Arc<ServerState>) {
let mut interval = tokio::time::interval(Duration::from_secs(*SESSION_ACTIVE_TIME));
loop {
interval.tick().await;
let res = gc(state.clone()).await;
match res {
Ok(_) => tracing::info!("gc completed"),
Err(s) => tracing::error!("gc failed:{}", s),
}
}
}
// 处理/auth
async fn auth(
State(state): State<Arc<ServerState>>,
// Form(frm): Form<UserLoginForm>,
cookies: Cookies,
) -> StatusCode {
if let Some(session_token) = cookies.get(&COOKIE_NAME) {
tracing::info!("session:{}", session_token.value());
let Ok(s) = uuid::Uuid::from_str(session_token.value()) else {
return StatusCode::UNAUTHORIZED;
};
let mut locked = state.session.lock().await;
if locked.contains_key(&s) {
// FIX, when accessed /auth with correct cookie, the cookie's expiration is delayed
let Some(v) = locked.insert(
s,
Instant::now() + Duration::from_secs(*SESSION_ACTIVE_TIME),
) else {
tracing::info!("session:{} extended", session_token.value());
return StatusCode::UNAUTHORIZED;
};
if Instant::now() < v {
return StatusCode::OK;
}
}
}
return StatusCode::UNAUTHORIZED;
}
async fn login(
State(state): State<Arc<ServerState>>,
cookies: Cookies,
Query(mut params): Query<HashMap<String, String>>,
Form(frm): Form<UserLoginForm>,
) -> Result<Redirect, (StatusCode, &'static str)> {
let conn = state.db.acquire().await;
let Ok(mut conn) = conn else {
return Err((StatusCode::BAD_GATEWAY, "db连接错误"));
};
tracing::info!("{:?}", &frm);
let target = sqlx::query_as::<_, UserLoginForm>(
r#"
SELECT NAME, KEY FROM USERS WHERE NAME = ?
"#,
)
.bind(frm.username)
.fetch_optional(&mut *conn)
.await;
tracing::info!("{:?}", &target);
if let Ok(Some(target)) = target {
if check_otp(target.otp, frm.otp) {
let s = Uuid::new_v4();
let mut locked = state.session.lock().await;
locked.insert(
s.clone(),
Instant::now() + Duration::from_secs(*SESSION_ACTIVE_TIME),
);
let mut new_cookie = Cookie::new(&*COOKIE_NAME, s.to_string());
new_cookie.set_domain(".aaronhu.cn");
cookies.add(new_cookie);
if let Some(original_uri) = params.get("original_url") {
return Ok(Redirect::to(&original_uri));
}
return Err((StatusCode::ACCEPTED, "ok"));
} else {
return Err((StatusCode::UNAUTHORIZED, "wrong password"));
}
}
return Err((StatusCode::BAD_GATEWAY, "unreachable"));
}
/// Handler for `GET /login`: renders the embedded login page.
///
/// If the request carries a non-empty `X-Original-URI` header (presumably set by the
/// reverse proxy — TODO confirm), the template variable `url` becomes
/// `?original_url=<percent-encoded URI>`, which lets the form send the user back after
/// logging in. Otherwise `url` is empty. A rendering failure yields the body `Error`.
///
/// Request headers are deliberately not logged: they include the `Cookie` header,
/// which carries session tokens.
async fn login_page(headers: HeaderMap<HeaderValue>) -> impl IntoResponse {
    let mut env = Environment::new();
    env.add_template("login.html", LOGIN_PAGE_HTML)
        .expect("embedded login page is not a valid template");
    let template = env
        .get_template("login.html")
        .expect("login.html was registered just above");

    let url = headers
        .get("X-Original-URI")
        .and_then(|value| value.to_str().ok())
        .filter(|uri| !uri.is_empty())
        .map(|uri| {
            tracing::info!("redirect to {}", uri);
            // Encode the value so `&`, `#`, `+`, `=` in the original URI survive the
            // round trip through `Query` instead of splitting or truncating it.
            format!("?original_url={}", percent_encode_query_value(uri))
        })
        .unwrap_or_default();

    Html(
        template
            .render(context! { url => url })
            .unwrap_or_else(|_| "Error".to_string()),
    )
}

/// Percent-encodes `raw` for use as one query-string value.
///
/// Keeps RFC 3986 unreserved characters (ASCII letters, digits, `-`, `.`, `_`, `~`) and
/// `/`; every other byte becomes `%XX` with upper-case hex digits.
fn percent_encode_query_value(raw: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'/') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
/// Verifies a user-submitted TOTP code against the key stored for that user.
///
/// The bytes of `key_from_db` are used directly as the raw secret (no base32 decoding).
/// Parameters: SHA-1, 6 digits, 30-second step, skew of 1 step.
/// Returns `false` in three cases: the TOTP cannot be constructed (the key is rejected
/// by `totp_rs`), the current time cannot be read, or the code does not match.
pub fn check_otp(key_from_db: String, user_input_otp: String) -> bool {
    use totp_rs::{Algorithm, TOTP};
    let Ok(totp) = TOTP::new(Algorithm::SHA1, 6, 1, 30, key_from_db.into_bytes()) else {
        return false;
    };
    // `check_current` honours the configured skew and compares in constant time.
    // Generating only the current code and comparing with `==` did neither.
    totp.check_current(&user_input_otp).unwrap_or(false)
}
/// Removes every session whose expiry is not strictly in the future.
///
/// Currently always returns `Ok(())`. `gc_task` logs the `Err` case if one ever occurs.
async fn gc(state: Arc<ServerState>) -> Result<(), String> {
    let mut sessions = state.session.lock().await;
    let now = Instant::now();
    let before = sessions.len();
    sessions.retain(|_, expiry| *expiry > now);
    // Log counts only: the map's keys are live session tokens.
    tracing::info!("gc fired, active sessions {} -> {}", before, sessions.len());
    Ok(())
}