diff --git a/Cargo.lock b/Cargo.lock index aeeec3d..b876c7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -124,6 +124,7 @@ dependencies = [ "askama", "axum", "minijinja", + "once_cell", "serde", "sqlx", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 2fc123f..da1e6ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,4 @@ tower-http = { version = "0.4.4", features = ["trace"] } tracing = "0.1.40" askama = "0.10" minijinja = "1.0.9" +once_cell = "1.18.0" diff --git a/src/main.rs b/src/main.rs index fc6c556..72ac862 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,40 +1,41 @@ use axum::extract::Query; +use axum::http::{HeaderMap, HeaderValue, Uri}; use axum::response::{Html, Redirect}; -use axum::http::{Uri, HeaderMap, HeaderValue}; -use axum::{ - extract::State, - http::StatusCode, - response::IntoResponse, - routing::get, - Form, Router, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Form, Router}; +use minijinja::{context, Environment}; use serde::Deserialize; use sqlx::sqlite::SqlitePool; -use std::{env, borrow::BorrowMut}; use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; +use std::{borrow::BorrowMut, env}; use std::{collections::HashMap, str::FromStr}; use tokio::sync::Mutex; use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; -use uuid::Uuid; use tower_http::trace::{self, TraceLayer}; use tracing::Level; -use minijinja::{Environment, context}; +use uuid::Uuid; pub struct ServerState { pub db: sqlx::Pool, pub session: Mutex>, // started: Instant, } -#[derive(Deserialize, sqlx::FromRow,Debug)] +#[derive(Deserialize, sqlx::FromRow, Debug)] pub struct UserLoginForm { #[sqlx(rename = "NAME")] pub username: String, #[sqlx(rename = "KEY")] pub otp: String, } -const COOKIE_NAME: &str = "aaron_auth"; -const SESSION_ACTIVE_TIME: u64 = 600; +use once_cell::sync::Lazy; +static COOKIE_NAME: Lazy = + Lazy::new(|| env::var("COOKIE_NAME").unwrap_or("aaron_auth".to_string())); +const SESSION_ACTIVE_TIME: Lazy = Lazy::new(|| { + env::var("SESSION_ACTIVE_TIME") + .ok() + .and_then(|value| value.parse().ok()) + .unwrap_or(600) +}); const LOGIN_PAGE_HTML: &str = include_str!("../loginpage.html"); #[tokio::main] async fn main() { @@ -69,15 +70,14 @@ async fn main() { .unwrap(); } - -async fn gc_task(state:Arc){ - let mut interval = tokio::time::interval(Duration::from_secs(SESSION_ACTIVE_TIME)); +async fn gc_task(state: Arc) { + let mut interval = tokio::time::interval(Duration::from_secs(*SESSION_ACTIVE_TIME)); loop { interval.tick().await; let res = gc(state.clone()).await; match res { Ok(_) => tracing::info!("gc completed"), - Err(s) => tracing::error!("gc failed:{}",s), + Err(s) => tracing::error!("gc failed:{}", s), } } } @@ -88,15 +88,19 @@ async fn auth( // Form(frm): Form, cookies: Cookies, ) -> StatusCode { - if let Some(session_token) = cookies.get(COOKIE_NAME) { - tracing::info!("session:{}",session_token.value()); + if let Some(session_token) = cookies.get(&COOKIE_NAME) { + tracing::info!("session:{}", session_token.value()); let Ok(s) = uuid::Uuid::from_str(session_token.value()) else { return StatusCode::UNAUTHORIZED; }; let mut locked = state.session.lock().await; - if locked.contains_key(&s) { // FIX, when accessed /auth with correct cookie, the cookie's expiration is delayed - let Some(v) = locked.insert(s,Instant::now()+Duration::from_secs(SESSION_ACTIVE_TIME)) else { - tracing::info!("session:{} extended",session_token.value()); + if locked.contains_key(&s) { + // FIX, when accessed /auth with correct cookie, the cookie's expiration is delayed + let Some(v) = locked.insert( + s, + Instant::now() + Duration::from_secs(*SESSION_ACTIVE_TIME), + ) else { + tracing::info!("session:{} extended", session_token.value()); return StatusCode::UNAUTHORIZED; }; if Instant::now() < v { @@ -117,7 +121,7 @@ async fn login( let Ok(mut conn) = conn else { return Err((StatusCode::BAD_GATEWAY, "db连接错误")); }; - tracing::info!("{:?}",&frm); + tracing::info!("{:?}", &frm); let target = sqlx::query_as::<_, UserLoginForm>( r#" SELECT NAME, KEY FROM USERS WHERE NAME = ? @@ -126,14 +130,17 @@ async fn login( .bind(frm.username) .fetch_optional(&mut *conn) .await; - tracing::info!("{:?}",&target); + tracing::info!("{:?}", &target); if let Ok(Some(target)) = target { if check_otp(target.otp, frm.otp) { let s = Uuid::new_v4(); let mut locked = state.session.lock().await; - locked.insert(s.clone(), Instant::now() + Duration::from_secs(SESSION_ACTIVE_TIME)); - let mut new_cookie = Cookie::new(COOKIE_NAME, s.to_string()); + locked.insert( + s.clone(), + Instant::now() + Duration::from_secs(*SESSION_ACTIVE_TIME), + ); + let mut new_cookie = Cookie::new(&*COOKIE_NAME, s.to_string()); new_cookie.set_domain(".aaronhu.cn"); cookies.add(new_cookie); if let Some(original_uri) = params.get("original_url") { @@ -142,29 +149,35 @@ async fn login( return Err((StatusCode::ACCEPTED, "ok")); } else { - return Err((StatusCode::UNAUTHORIZED,"wrong password")); + return Err((StatusCode::UNAUTHORIZED, "wrong password")); } } return Err((StatusCode::BAD_GATEWAY, "unreachable")); } -async fn login_page( - headers: HeaderMap, -)-> impl IntoResponse{ - tracing::info!("Headers: {:#?}",headers); +async fn login_page(headers: HeaderMap) -> impl IntoResponse { + tracing::info!("Headers: {:#?}", headers); let mut env = Environment::new(); env.add_template("login.html", LOGIN_PAGE_HTML).unwrap(); let template = env.get_template("login.html").unwrap(); if let Some(original_uri) = headers.get("X-Original-URI") { if let Ok(uri) = original_uri.to_str() { - tracing::info!("redirect to {}",uri); + tracing::info!("redirect to {}", uri); if !uri.is_empty() { - let uri = "?original_url=".to_owned()+uri; - return Html(template.render(context! { url => uri }).unwrap_or("Error".to_string())) + let uri = "?original_url=".to_owned() + uri; + return Html( + template + .render(context! { url => uri }) + .unwrap_or("Error".to_string()), + ); } } } - return Html(template.render(context! { url => String::new() }).unwrap_or("Error".to_string())) + return Html( + template + .render(context! { url => String::new() }) + .unwrap_or("Error".to_string()), + ); } pub fn check_otp(key_from_db: String, user_input_otp: String) -> bool { @@ -186,11 +199,11 @@ pub fn check_otp(key_from_db: String, user_input_otp: String) -> bool { return false; } -async fn gc(state:Arc)->Result<(),String>{ +async fn gc(state: Arc) -> Result<(), String> { let mut locked = state.session.lock().await; let current_time = Instant::now(); - tracing::info!("before gc ,active Sessions {:?}",locked); - locked.borrow_mut().retain(|_,v| *v > current_time); - tracing::info!("gc fired,active Sessions {:?}",locked); + tracing::info!("before gc ,active Sessions {:?}", locked); + locked.borrow_mut().retain(|_, v| *v > current_time); + tracing::info!("gc fired,active Sessions {:?}", locked); Ok(()) }