Refactor auth a bit

This commit is contained in:
soup 2024-11-15 14:09:29 -05:00
parent 921bf66d20
commit 46bf383ad2
4 changed files with 42 additions and 38 deletions

View file

@ -71,3 +71,22 @@ func queryDeleteSession(db *sql.DB, sessionId string) error {
func SessionDelete(db *sql.DB, sessionId string) error { func SessionDelete(db *sql.DB, sessionId string) error {
return queryDeleteSession(db, sessionId) return queryDeleteSession(db, sessionId)
} }
type SessionInfo struct {
IsAdmin bool
SessionId string
}
func queryCheckAuth(db *sql.DB, sessionId string) (SessionInfo, error) {
sessionCount := 0
err := db.QueryRow("select count(*) from session where session_id = ?", sessionId).Scan(&sessionCount)
if err != nil {
return SessionInfo{}, err
}
return SessionInfo{IsAdmin: sessionCount == 1, SessionId: sessionId}, nil
}
func SessionCheck(db *sql.DB, sessionId string) (SessionInfo, error) {
return queryCheckAuth(db, sessionId)
}

View file

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"log" "log"
"net/http" "net/http"
"shelves/backend/auth"
"time" "time"
) )
@ -31,34 +32,6 @@ func Log(next http.Handler) http.Handler {
}) })
} }
type AuthInfo struct {
IsAdmin bool
SessionId string
}
func queryCheckAuth(db *sql.DB, sessionId string) (AuthInfo, error) {
sessionCount := 0
err := db.QueryRow("select count(*) from session where session_id = ?", sessionId).Scan(&sessionCount)
if err != nil {
return AuthInfo{}, err
}
return AuthInfo{IsAdmin: sessionCount == 1, SessionId: sessionId}, nil
}
func checkAuthed(db *sql.DB, r *http.Request) (AuthInfo, error) {
cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID")
if err == http.ErrNoCookie {
return AuthInfo{}, nil
}
if err != nil {
return AuthInfo{}, err
}
sessionId := cookie.Value
return queryCheckAuth(db, sessionId)
}
func queryNeedsOwnerSetup(db *sql.DB) (bool, error) { func queryNeedsOwnerSetup(db *sql.DB) (bool, error) {
count := 0 count := 0
err := db.QueryRow("select count(*) from owner_settings").Scan(&count) err := db.QueryRow("select count(*) from owner_settings").Scan(&count)
@ -69,9 +42,21 @@ func queryNeedsOwnerSetup(db *sql.DB) (bool, error) {
return count != 1, nil return count != 1, nil
} }
func sessionInfo(db *sql.DB, r *http.Request) (auth.SessionInfo, error) {
cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID")
if err == http.ErrNoCookie {
return auth.SessionInfo{}, nil
}
if err != nil {
return auth.SessionInfo{}, err
}
return auth.SessionCheck(db, cookie.Value)
}
type Ctx struct { type Ctx struct {
DB *sql.DB DB *sql.DB
Auth AuthInfo Session auth.SessionInfo
NeedsOwnerSetup bool NeedsOwnerSetup bool
} }
@ -89,13 +74,13 @@ func WithCtx(db *sql.DB, next http.Handler) http.Handler {
} }
ctx := r.Context() ctx := r.Context()
auth, err := checkAuthed(db, r) session, err := sessionInfo(db, r)
if err != nil { if err != nil {
log.Printf("Error while querying auth info: %v\n", err) log.Printf("Error while querying session info: %v\n", err)
} }
ctx = context.WithValue(ctx, "__ctx", Ctx{ ctx = context.WithValue(ctx, "__ctx", Ctx{
DB: db, DB: db,
Auth: auth, Session: session,
NeedsOwnerSetup: needsOwnerSetup, NeedsOwnerSetup: needsOwnerSetup,
}) })
r = r.WithContext(ctx) r = r.WithContext(ctx)

View file

@ -61,7 +61,7 @@ func LoginGet(w http.ResponseWriter, r *http.Request) {
ctx := httpx.GetCtx(r) ctx := httpx.GetCtx(r)
redirectTo := getRedirectTo(r) redirectTo := getRedirectTo(r)
if ctx.Auth.IsAdmin { if ctx.Session.IsAdmin {
httpx.SeeOther(w, redirectTo) httpx.SeeOther(w, redirectTo)
return return
} }
@ -88,7 +88,7 @@ func LoginPost(w http.ResponseWriter, r *http.Request) {
return return
} }
if ctx.Auth.IsAdmin { if ctx.Session.IsAdmin {
httpx.SeeOther(w, form.redirectTo) httpx.SeeOther(w, form.redirectTo)
return return
} }
@ -118,12 +118,12 @@ func LoginPost(w http.ResponseWriter, r *http.Request) {
func LoginDelete(w http.ResponseWriter, r *http.Request) { func LoginDelete(w http.ResponseWriter, r *http.Request) {
ctx := httpx.GetCtx(r) ctx := httpx.GetCtx(r)
if !ctx.Auth.IsAdmin { if !ctx.Session.IsAdmin {
httpx.SeeOther(w, "/") httpx.SeeOther(w, "/")
return return
} }
sessionId := ctx.Auth.SessionId sessionId := ctx.Session.SessionId
err := auth.SessionDelete(ctx.DB, sessionId) err := auth.SessionDelete(ctx.DB, sessionId)
if err != nil { if err != nil {
httpx.InternalServerError(w, err) httpx.InternalServerError(w, err)

View file

@ -55,7 +55,7 @@ func queryGetOwnerSettings(db *sql.DB) (setupForm, error) {
func SetupGet(w http.ResponseWriter, r *http.Request) { func SetupGet(w http.ResponseWriter, r *http.Request) {
ctx := httpx.GetCtx(r) ctx := httpx.GetCtx(r)
if !ctx.Auth.IsAdmin && !ctx.NeedsOwnerSetup { if !ctx.Session.IsAdmin && !ctx.NeedsOwnerSetup {
httpx.LoginRedirect(w, *r.URL) httpx.LoginRedirect(w, *r.URL)
return return
} }
@ -115,7 +115,7 @@ func updateOwnerSettings(db *sql.DB, f setupForm) error {
func SetupPost(w http.ResponseWriter, r *http.Request) { func SetupPost(w http.ResponseWriter, r *http.Request) {
ctx := httpx.GetCtx(r) ctx := httpx.GetCtx(r)
if !ctx.Auth.IsAdmin && !ctx.NeedsOwnerSetup { if !ctx.Session.IsAdmin && !ctx.NeedsOwnerSetup {
httpx.Unauthorized(w) httpx.Unauthorized(w)
return return
} }