diff --git a/backend/auth/auth.go b/backend/auth/auth.go index f9503d8..61891fe 100644 --- a/backend/auth/auth.go +++ b/backend/auth/auth.go @@ -71,3 +71,22 @@ func queryDeleteSession(db *sql.DB, sessionId string) error { func SessionDelete(db *sql.DB, sessionId string) error { return queryDeleteSession(db, sessionId) } + +type SessionInfo struct { + IsAdmin bool + SessionId string +} + +func queryCheckAuth(db *sql.DB, sessionId string) (SessionInfo, error) { + sessionCount := 0 + err := db.QueryRow("select count(*) from session where session_id = ?", sessionId).Scan(&sessionCount) + if err != nil { + return SessionInfo{}, err + } + + return SessionInfo{IsAdmin: sessionCount == 1, SessionId: sessionId}, nil +} + +func SessionCheck(db *sql.DB, sessionId string) (SessionInfo, error) { + return queryCheckAuth(db, sessionId) +} diff --git a/backend/httpx/middleware.go b/backend/httpx/middleware.go index 2fb9d1e..4e63fe7 100644 --- a/backend/httpx/middleware.go +++ b/backend/httpx/middleware.go @@ -5,6 +5,7 @@ import ( "database/sql" "log" "net/http" + "shelves/backend/auth" "time" ) @@ -31,34 +32,6 @@ func Log(next http.Handler) http.Handler { }) } -type AuthInfo struct { - IsAdmin bool - SessionId string -} - -func queryCheckAuth(db *sql.DB, sessionId string) (AuthInfo, error) { - sessionCount := 0 - err := db.QueryRow("select count(*) from session where session_id = ?", sessionId).Scan(&sessionCount) - if err != nil { - return AuthInfo{}, err - } - - return AuthInfo{IsAdmin: sessionCount == 1, SessionId: sessionId}, nil -} - -func checkAuthed(db *sql.DB, r *http.Request) (AuthInfo, error) { - cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID") - if err == http.ErrNoCookie { - return AuthInfo{}, nil - } - if err != nil { - return AuthInfo{}, err - } - - sessionId := cookie.Value - return queryCheckAuth(db, sessionId) -} - func queryNeedsOwnerSetup(db *sql.DB) (bool, error) { count := 0 err := db.QueryRow("select count(*) from owner_settings").Scan(&count) @@ -69,9 +42,21 @@ func queryNeedsOwnerSetup(db *sql.DB) (bool, error) { return count != 1, nil } +func sessionInfo(db *sql.DB, r *http.Request) (auth.SessionInfo, error) { + cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID") + if err == http.ErrNoCookie { + return auth.SessionInfo{}, nil + } + if err != nil { + return auth.SessionInfo{}, err + } + + return auth.SessionCheck(db, cookie.Value) +} + type Ctx struct { DB *sql.DB - Auth AuthInfo + Session auth.SessionInfo NeedsOwnerSetup bool } @@ -89,13 +74,13 @@ func WithCtx(db *sql.DB, next http.Handler) http.Handler { } ctx := r.Context() - auth, err := checkAuthed(db, r) + session, err := sessionInfo(db, r) if err != nil { - log.Printf("Error while querying auth info: %v\n", err) + log.Printf("Error while querying session info: %v\n", err) } ctx = context.WithValue(ctx, "__ctx", Ctx{ DB: db, - Auth: auth, + Session: session, NeedsOwnerSetup: needsOwnerSetup, }) r = r.WithContext(ctx) diff --git a/backend/routes/login.go b/backend/routes/login.go index a0ecac4..37daad8 100644 --- a/backend/routes/login.go +++ b/backend/routes/login.go @@ -61,7 +61,7 @@ func LoginGet(w http.ResponseWriter, r *http.Request) { ctx := httpx.GetCtx(r) redirectTo := getRedirectTo(r) - if ctx.Auth.IsAdmin { + if ctx.Session.IsAdmin { httpx.SeeOther(w, redirectTo) return } @@ -88,7 +88,7 @@ func LoginPost(w http.ResponseWriter, r *http.Request) { return } - if ctx.Auth.IsAdmin { + if ctx.Session.IsAdmin { httpx.SeeOther(w, form.redirectTo) return } @@ -118,12 +118,12 @@ func LoginPost(w http.ResponseWriter, r *http.Request) { func LoginDelete(w http.ResponseWriter, r *http.Request) { ctx := httpx.GetCtx(r) - if !ctx.Auth.IsAdmin { + if !ctx.Session.IsAdmin { httpx.SeeOther(w, "/") return } - sessionId := ctx.Auth.SessionId + sessionId := ctx.Session.SessionId err := auth.SessionDelete(ctx.DB, sessionId) if err != nil { httpx.InternalServerError(w, err) diff --git a/backend/routes/setup.go b/backend/routes/setup.go index c8e4add..c0a57b1 100644 --- a/backend/routes/setup.go +++ b/backend/routes/setup.go @@ -55,7 +55,7 @@ func queryGetOwnerSettings(db *sql.DB) (setupForm, error) { func SetupGet(w http.ResponseWriter, r *http.Request) { ctx := httpx.GetCtx(r) - if !ctx.Auth.IsAdmin && !ctx.NeedsOwnerSetup { + if !ctx.Session.IsAdmin && !ctx.NeedsOwnerSetup { httpx.LoginRedirect(w, *r.URL) return } @@ -115,7 +115,7 @@ func updateOwnerSettings(db *sql.DB, f setupForm) error { func SetupPost(w http.ResponseWriter, r *http.Request) { ctx := httpx.GetCtx(r) - if !ctx.Auth.IsAdmin && !ctx.NeedsOwnerSetup { + if !ctx.Session.IsAdmin && !ctx.NeedsOwnerSetup { httpx.Unauthorized(w) return }