shelves/backend/http/middleware.go
2024-11-14 17:39:31 -05:00

107 lines
2.1 KiB
Go

package http
import (
"context"
"database/sql"
"log"
"net/http"
"time"
"shelves/backend/routes"
)
// LogWrapper is an http.ResponseWriter that forwards everything to the
// embedded ResponseWriter while copying the status code passed to
// WriteHeader into *statusCode, so request-logging middleware can report it.
//
// statusCode must point at a valid int before WriteHeader is called, because
// WriteHeader dereferences it unconditionally. Initialise that int to 200 so
// a handler that never calls WriteHeader (net/http sends an implicit 200 on
// the first Write) is still reported correctly.
type LogWrapper struct {
	statusCode *int
	http.ResponseWriter
}

// WriteHeader records statusCode and forwards it to the wrapped writer.
//
// NOTE(review): net/http sends only the first non-1xx status (later calls
// are ignored as superfluous, and the first Write implies 200). This method
// records the most recent call instead, so the logged status can differ from
// what the client received when a handler writes headers more than once.
func (w LogWrapper) WriteHeader(statusCode int) {
	*w.statusCode = statusCode
	w.ResponseWriter.WriteHeader(statusCode)
}

// Unwrap returns the wrapped ResponseWriter. Embedding the ResponseWriter
// interface hides optional capabilities such as http.Flusher and
// http.Hijacker. http.ResponseController follows Unwrap to reach them, so
// handlers behind the logging middleware can still flush, hijack, or set
// deadlines.
func (w LogWrapper) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// Log is middleware that, after next has handled a request, writes one line
// to the standard logger containing the request method, URL path, response
// status code and the time spent inside next.
func Log(next http.Handler) http.Handler {
	logRequest := func(w http.ResponseWriter, r *http.Request) {
		// Remains 200 unless the handler calls WriteHeader; a handler that only
		// calls Write also gets net/http's implicit 200.
		status := http.StatusOK
		recorder := LogWrapper{ResponseWriter: w, statusCode: &status}

		began := time.Now()
		next.ServeHTTP(recorder, r)
		elapsed := time.Since(began)

		log.Println(r.Method, r.URL.Path, status, elapsed)
	}
	return http.HandlerFunc(logRequest)
}
// AuthInfo describes what the current request is allowed to do.
type AuthInfo struct {
	// is_admin is true when the request carried a non-empty owner session
	// cookie whose value matched exactly one row in the session table.
	is_admin bool
}

// checkAuthed resolves the request's SHELVES_OWNER_SESSION_ID cookie against
// the session table.
//
// A missing or empty cookie yields a zero AuthInfo and a nil error without
// touching the database. Otherwise the session table is queried for the
// cookie value, and the request counts as admin when exactly one row
// matches. The query runs under r.Context(), so it is abandoned if the
// request is cancelled (e.g. the client disconnects). On any error the zero
// AuthInfo is returned together with that error.
func checkAuthed(db *sql.DB, r *http.Request) (AuthInfo, error) {
	cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID")
	if err == http.ErrNoCookie {
		return AuthInfo{}, nil
	}
	if err != nil {
		return AuthInfo{}, err
	}

	sessionID := cookie.Value
	// A blank ID can never identify a real session. Rejecting it up front
	// skips a useless query and guarantees an empty cookie can never match a
	// row that happens to be stored with an empty session_id.
	if sessionID == "" {
		return AuthInfo{}, nil
	}

	var sessionCount int
	err = db.QueryRowContext(r.Context(),
		"select count(*) from session where session_id = ?", sessionID).Scan(&sessionCount)
	if err != nil {
		return AuthInfo{}, err
	}
	return AuthInfo{is_admin: sessionCount == 1}, nil
}
// queryNeedsOwnerSetup reports whether initial owner setup is still pending,
// which is the case whenever the owner_settings table does not contain
// exactly one row. If the count query fails, it returns false along with the
// error.
func queryNeedsOwnerSetup(db *sql.DB) (bool, error) {
	var rows int
	if err := db.QueryRow("select count(*) from owner_settings").Scan(&rows); err != nil {
		return false, err
	}
	pending := rows != 1
	return pending, nil
}
// Ctx is the per-request state that WithCtx stores in the request context
// under the "__ctx" key.
type Ctx struct {
	// DB is the database handle that was passed to WithCtx.
	DB *sql.DB
	// Auth is the result of checkAuthed for the current request.
	Auth AuthInfo
}

// WithCtx wraps next with two steps run on every request:
//
//  1. If queryNeedsOwnerSetup reports that owner setup is pending, requests
//     to any path other than /setup get a 303 See Other redirect to /setup,
//     and /setup itself is served by routes.SetupGet directly. In both cases
//     next is not called.
//  2. Otherwise checkAuthed is run, and a Ctx holding db and the resulting
//     AuthInfo is stored in the request context under the "__ctx" key before
//     next is invoked.
//
// Errors from either helper are logged, and the request continues with
// whatever values that helper returned alongside the error.
func WithCtx(db *sql.DB, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		pending, err := queryNeedsOwnerSetup(db)
		if err != nil {
			log.Printf("Error while querying owner_settings: %v\n", err)
		}

		if pending {
			// NOTE(review): every HTTP method on /setup (including a form POST)
			// reaches routes.SetupGet here; confirm that is intended.
			if r.URL.Path == "/setup" {
				routes.SetupGet(w, r)
				return
			}
			w.Header().Add("Location", "/setup")
			w.WriteHeader(http.StatusSeeOther)
			return
		}

		auth, err := checkAuthed(db, r)
		if err != nil {
			log.Printf("Error while querying auth info: %v\n", err)
		}

		// NOTE(review): staticcheck (SA1029) flags a built-in string as a
		// context key because it can collide with other packages. Switching to
		// an unexported key type requires updating every reader of "__ctx"
		// (presumably in the routes package) at the same time.
		state := Ctx{DB: db, Auth: auth}
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "__ctx", state)))
	})
}