use axum::extract::Query; use axum::http::{HeaderMap, HeaderValue}; use axum::response::{Html, Redirect}; use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Form, Router}; use minijinja::{context, Environment}; use serde::Deserialize; use sqlx::sqlite::SqlitePool; use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use std::{borrow::BorrowMut, env}; use std::{collections::HashMap, str::FromStr}; use tokio::sync::Mutex; use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; use tower_http::trace::{self, TraceLayer}; use tracing::Level; use uuid::Uuid; pub struct ServerState { pub db: sqlx::Pool, pub session: Mutex>, // started: Instant, } #[derive(Deserialize, sqlx::FromRow, Debug)] pub struct UserLoginForm { #[sqlx(rename = "NAME")] pub username: String, #[sqlx(rename = "KEY")] pub otp: String, } use once_cell::sync::Lazy; static COOKIE_NAME: Lazy = Lazy::new(|| env::var("COOKIE_NAME").unwrap_or("aaron_auth".to_string())); const SESSION_ACTIVE_TIME: Lazy = Lazy::new(|| { env::var("SESSION_ACTIVE_TIME") .ok() .and_then(|value| value.parse().ok()) .unwrap_or(600) }); const LOGIN_PAGE_HTML: &str = include_str!("../loginpage.html"); #[tokio::main] async fn main() { // 初始化日志记录器 tracing_subscriber::fmt::init(); let pool = SqlitePool::connect(&env::var("DATABASE_URL").expect("DB URL NOT SPECIFIED")) .await .expect("DB OPEN FAILURE"); // our router let state = Arc::new(ServerState { db: pool, session: HashMap::new().into(), // started: std::time::Instant::now(), }); let app = Router::new() .route("/auth", get(auth)) // http://127.0.0.1:3000 .route("/login", get(login_page).post(login)) .with_state(state.clone()) .layer(CookieManagerLayer::new()) .layer( TraceLayer::new_for_http() .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO)) .on_response(trace::DefaultOnResponse::new().level(Level::INFO)), ); let port = env::var("PORT").unwrap_or("3000".to_string()); let port = port.parse::().unwrap_or(3000); let addr = SocketAddr::from(([127, 0, 0, 1], port)); // run it with hyper on localhost:3000 tokio::spawn(gc_task(state.clone())); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } async fn gc_task(state: Arc) { let mut interval = tokio::time::interval(Duration::from_secs(*SESSION_ACTIVE_TIME)); loop { interval.tick().await; let res = gc(state.clone()).await; match res { Ok(_) => tracing::info!("gc completed"), Err(s) => tracing::error!("gc failed:{}", s), } } } // 处理/auth async fn auth( State(state): State>, // Form(frm): Form, cookies: Cookies, ) -> StatusCode { if let Some(session_token) = cookies.get(&COOKIE_NAME) { tracing::info!("session:{}", session_token.value()); let Ok(s) = uuid::Uuid::from_str(session_token.value()) else { return StatusCode::UNAUTHORIZED; }; let mut locked = state.session.lock().await; if let std::collections::hash_map::Entry::Occupied(mut e) = locked.entry(s) { // FIX, when accessed /auth with correct cookie, the cookie's expiration is delayed let Some(v) = Some(e.insert(Instant::now() + Duration::from_secs(*SESSION_ACTIVE_TIME))) else { tracing::info!("session:{} extended", session_token.value()); return StatusCode::UNAUTHORIZED; }; if Instant::now() < v { return StatusCode::OK; } } } StatusCode::UNAUTHORIZED } async fn login( State(state): State>, cookies: Cookies, Query(params): Query>, Form(frm): Form, ) -> Result { let conn = state.db.acquire().await; let Ok(mut conn) = conn else { return Err((StatusCode::BAD_GATEWAY, "db连接错误")); }; tracing::info!("{:?}", &frm); let target = sqlx::query_as::<_, UserLoginForm>( r#" SELECT NAME, KEY FROM USERS WHERE NAME = ? "#, ) .bind(frm.username) .fetch_optional(&mut *conn) .await; tracing::info!("{:?}", &target); if let Ok(Some(target)) = target { if check_otp(target.otp, frm.otp) { let s = Uuid::new_v4(); let mut locked = state.session.lock().await; locked.insert( s, Instant::now() + Duration::from_secs(*SESSION_ACTIVE_TIME), ); let mut new_cookie = Cookie::new(&*COOKIE_NAME, s.to_string()); new_cookie.set_domain(".aaronhu.cn"); cookies.add(new_cookie); if let Some(original_uri) = params.get("original_url") { return Ok(Redirect::to(original_uri)); } return Err((StatusCode::ACCEPTED, "ok")); } else { return Err((StatusCode::UNAUTHORIZED, "wrong password")); } } Err((StatusCode::BAD_GATEWAY, "unreachable")) } async fn login_page(headers: HeaderMap) -> impl IntoResponse { tracing::info!("Headers: {:#?}", headers); let mut env = Environment::new(); env.add_template("login.html", LOGIN_PAGE_HTML).unwrap(); let template = env.get_template("login.html").unwrap(); if let Some(original_uri) = headers.get("X-Original-URI") { if let Ok(uri) = original_uri.to_str() { tracing::info!("redirect to {}", uri); if !uri.is_empty() { let uri = "?original_url=".to_owned() + uri; return Html( template .render(context! { url => uri }) .unwrap_or("Error".to_string()), ); } } } Html( template .render(context! { url => String::new() }) .unwrap_or("Error".to_string()), ) } pub fn check_otp(key_from_db: String, user_input_otp: String) -> bool { use totp_rs::{Algorithm, Secret, TOTP}; let totp = TOTP::new( Algorithm::SHA1, 6, 1, 30, Secret::Raw(key_from_db.as_bytes().to_vec()) .to_bytes() .unwrap(), ); if let Ok(otp) = totp { if let Ok(token) = otp.generate_current() { return token == user_input_otp; } } false } async fn gc(state: Arc) -> Result<(), String> { let mut locked = state.session.lock().await; let current_time = Instant::now(); tracing::info!("before gc ,active Sessions {:?}", locked); locked.borrow_mut().retain(|_, v| *v > current_time); tracing::info!("gc fired,active Sessions {:?}", locked); Ok(()) }