This commit is contained in:
yly 2023-10-30 16:38:48 +08:00
commit a9fd136a55
7 changed files with 2386 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

View File

@ -0,0 +1,26 @@
{
"db_name": "SQLite",
"query": "\nSELECT NAME, KEY FROM USERS WHERE NAME = ?1 AND KEY = ?2\n ",
"describe": {
"columns": [
{
"name": "NAME",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "KEY",
"ordinal": 1,
"type_info": "Text"
}
],
"parameters": {
"Right": 2
},
"nullable": [
false,
false
]
},
"hash": "6b379715355f3443d66111020be5c3c7ec2a57c9789d04485b998fd934d6b0a8"
}

2104
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

18
Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "auth"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = {version = "0.6.20", features = [ "default", "headers" ]}
tokio = { version = "1.0", features = ["full"] }
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
tower-cookies = "0.9.0"
serde = "1.0.190"
uuid = { version = "1.5.0", features = ["v4", "fast-rng"] }
totp-rs = "5.4.0"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tower-http = { version = "0.4.4", features = ["trace"] }
tracing = "0.1.40"

BIN
auth.db Normal file

Binary file not shown.

67
loginpage.html Normal file
View File

@ -0,0 +1,67 @@
<!DOCTYPE html>
<html>
<head>
<title>Login</title>
<style>
body {
font-family: Arial, sans-serif;
background-color: #f2f2f2;
}
.container {
max-width: 400px;
margin: 0 auto;
padding: 20px;
background-color: #fff;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}
h1 {
text-align: center;
}
label {
display: block;
margin-top: 10px;
}
input[type="text"],
input[type="password"] {
width: 100%;
padding: 10px;
border: 1px solid #ccc;
border-radius: 4px;
box-sizing: border-box;
}
input[type="submit"] {
width: 100%;
padding: 10px;
margin-top: 20px;
background-color: #4CAF50;
color: #fff;
border: none;
border-radius: 4px;
cursor: pointer;
}
input[type="submit"]:hover {
background-color: #45a049;
}
</style>
</head>
<body>
<div class="container">
<h1>Login</h1>
<form action="/login" method="POST">
<label for="username">Username:</label>
<input type="text" id="username" name="username" required>
<label for="password">Password:</label>
<input type="password" id="password" name="otp" required>
<input type="submit" value="Submit">
</form>
</div>
</body>
</html>

170
src/main.rs Normal file
View File

@ -0,0 +1,170 @@
use axum::{Extension, TypedHeader};
use axum::response::{Html, Redirect};
use axum::http::{Uri, HeaderMap, HeaderValue};
use axum::{
extract::State,
http::StatusCode,
response::IntoResponse,
routing::get,
Form, Router,
};
use serde::Deserialize;
use sqlx::sqlite::SqlitePool;
use std::{env, borrow::BorrowMut};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::{collections::HashMap, str::FromStr};
use tokio::sync::Mutex;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies};
use uuid::Uuid;
use tower_http::trace::{self, TraceLayer};
use tracing::Level;
pub struct ServerState {
pub db: sqlx::Pool<sqlx::Sqlite>,
pub session: Mutex<HashMap<Uuid, Instant>>,
// started: Instant,
}
#[derive(Deserialize, sqlx::FromRow,Debug)]
pub struct UserLoginForm {
#[sqlx(rename = "NAME")]
pub username: String,
#[sqlx(rename = "KEY")]
pub otp: String,
}
const COOKIE_NAME: &str = "aaron_auth";
const SESSION_ACTIVE_TIME: u64 = 600;
const LOGIN_PAGE_HTML: &str = include_str!("../loginpage.html");
#[tokio::main]
async fn main() {
// 初始化日志记录器
tracing_subscriber::fmt::init();
let pool = SqlitePool::connect(&env::var("DATABASE_URL").expect("DB URL NOT SPECIFIED"))
.await
.expect("DB OPEN FAILURE"); // our router
let state = Arc::new(ServerState {
db: pool,
session: HashMap::new().into(),
// started: std::time::Instant::now(),
});
let app = Router::new()
.route("/auth", get(auth)) // http://127.0.0.1:3000
.route("/login", get(login_page).post(login))
.with_state(state.clone())
.layer(CookieManagerLayer::new())
.layer(
TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
);
let port = env::var("PORT").unwrap_or("3000".to_string());
let port = port.parse::<u16>().unwrap_or(3000);
let addr = SocketAddr::from(([127, 0, 0, 1], port));
// run it with hyper on localhost:3000
tokio::spawn(gc_task(state.clone()));
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
async fn gc_task(state:Arc<ServerState>){
let mut interval = tokio::time::interval(Duration::from_secs(SESSION_ACTIVE_TIME));
loop {
interval.tick().await;
let res = gc(state.clone()).await;
match res {
Ok(_) => tracing::info!("gc completed"),
Err(s) => tracing::error!("gc failed:{}",s),
}
}
}
// 处理/auth
async fn auth(
State(state): State<Arc<ServerState>>,
// Form(frm): Form<UserLoginForm>,
cookies: Cookies,
) -> StatusCode {
if let Some(session_token) = cookies.get(COOKIE_NAME) {
let Ok(s) = uuid::Uuid::from_str(session_token.value()) else {
return StatusCode::UNAUTHORIZED;
};
let mut locked = state.session.lock().await;
let Some(v) = locked.insert(s,Instant::now()+Duration::from_secs(SESSION_ACTIVE_TIME)) else {
return StatusCode::UNAUTHORIZED;
};
if Instant::now() < v {
return StatusCode::OK;
}
}
return StatusCode::UNAUTHORIZED;
}
async fn login(
State(state): State<Arc<ServerState>>,
cookies: Cookies,
headers: HeaderMap<HeaderValue>,
Form(frm): Form<UserLoginForm>,
) -> impl IntoResponse {
let conn = state.db.acquire().await;
let Ok(mut conn) = conn else {
return (StatusCode::BAD_GATEWAY, "db连接错误");
};
tracing::info!("{:?}",&frm);
let target = sqlx::query_as::<_, UserLoginForm>(
r#"
SELECT NAME, KEY FROM USERS WHERE NAME = ?
"#,
)
.bind(frm.username)
.fetch_optional(&mut *conn)
.await;
tracing::info!("{:?}",&target);
if let Ok(Some(target)) = target {
if check_otp(target.otp, frm.otp) {
let s = Uuid::new_v4();
let mut locked = state.session.lock().await;
locked.insert(s.clone(), Instant::now() + Duration::from_secs(SESSION_ACTIVE_TIME));
cookies.add(Cookie::new(COOKIE_NAME, s.to_string()));
return (StatusCode::ACCEPTED, "ok");
} else {
return (StatusCode::UNAUTHORIZED,"wrong password");
}
}
return (StatusCode::BAD_GATEWAY, "unreachable");
}
async fn login_page(
)-> impl IntoResponse{
return Html(LOGIN_PAGE_HTML);
}
pub fn check_otp(key_from_db: String, user_input_otp: String) -> bool {
use totp_rs::{Algorithm, Secret, TOTP};
let totp = TOTP::new(
Algorithm::SHA1,
6,
1,
30,
Secret::Raw(key_from_db.as_bytes().to_vec())
.to_bytes()
.unwrap(),
);
if let Ok(otp) = totp {
if let Ok(token) = otp.generate_current() {
return token == user_input_otp;
}
}
return false;
}
async fn gc(state:Arc<ServerState>)->Result<(),String>{
let mut locked = state.session.lock().await;
let current_time = Instant::now();
locked.borrow_mut().retain(|_,v| *v < current_time);
tracing::info!("gc fired,active Sessions {:?}",locked);
Ok(())
}