107 lines
2.1 KiB
Go
107 lines
2.1 KiB
Go
package http
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"log"
|
|
"net/http"
|
|
"time"
|
|
|
|
"shelves/backend/routes"
|
|
)
|
|
|
|
// LogWrapper is an http.ResponseWriter that remembers the status code passed
// to WriteHeader, so that the Log middleware can report it after the handler
// returns. The code is written through a pointer, which lets every copy of
// this value-typed wrapper update the same variable.
//
// NOTE(review): only http.ResponseWriter is embedded, so optional interfaces
// of the underlying writer (http.Flusher, http.Hijacker, ...) are not exposed.
type LogWrapper struct {
	// statusCode is where WriteHeader stores the code. It must be non-nil:
	// WriteHeader dereferences it unconditionally.
	statusCode *int

	// ResponseWriter is the wrapped writer; Header and Write are promoted
	// from it unchanged.
	http.ResponseWriter
}

// WriteHeader stores code in the shared status slot and then forwards it to
// the wrapped ResponseWriter.
//
// NOTE(review): every call overwrites the stored code, while net/http only
// honours the first final (non-1xx) WriteHeader, so a handler that calls it
// twice is logged with the later, ignored code.
func (lw LogWrapper) WriteHeader(code int) {
	slot := lw.statusCode
	*slot = code
	lw.ResponseWriter.WriteHeader(code)
}
|
|
|
|
func Log(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
statusCode := 200
|
|
wrapper := LogWrapper{statusCode: &statusCode, ResponseWriter: w}
|
|
start := time.Now()
|
|
next.ServeHTTP(wrapper, r)
|
|
|
|
log.Println(r.Method, r.URL.Path, statusCode, time.Since(start))
|
|
})
|
|
}
|
|
|
|
// AuthInfo describes what the current request is authorised to do.
type AuthInfo struct {
	// is_admin is true when checkAuthed matched the request's owner-session
	// cookie against exactly one row of the session table.
	is_admin bool
}

// checkAuthed derives the AuthInfo for r from its SHELVES_OWNER_SESSION_ID
// cookie.
//
// A request with no such cookie, or with an empty cookie value, is treated as
// unauthenticated. It gets the zero AuthInfo and a nil error without touching
// the database. Otherwise the number of session rows whose session_id equals
// the cookie value is counted, and the request is an admin only if exactly one
// row matches.
//
// The lookup runs under r.Context(), so it is abandoned once the request is
// cancelled (for example when the client disconnects), and the context error
// is returned. On any error the returned AuthInfo is the zero value.
func checkAuthed(db *sql.DB, r *http.Request) (AuthInfo, error) {
	cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID")
	if err == http.ErrNoCookie {
		return AuthInfo{}, nil
	}
	if err != nil {
		return AuthInfo{}, err
	}

	sessionID := cookie.Value
	// An empty ID can never name a real session. Skip the round trip, and make
	// sure a stray row with an empty session_id could never grant admin access.
	if sessionID == "" {
		return AuthInfo{}, nil
	}

	var sessionCount int
	err = db.QueryRowContext(r.Context(),
		"select count(*) from session where session_id = ?", sessionID).
		Scan(&sessionCount)
	if err != nil {
		return AuthInfo{}, err
	}

	return AuthInfo{is_admin: sessionCount == 1}, nil
}
|
|
|
|
// queryNeedsOwnerSetup reports whether owner setup is still required. That is
// the case unless the owner_settings table holds exactly one row. On a query
// or scan error it returns false together with the error.
func queryNeedsOwnerSetup(db *sql.DB) (bool, error) {
	var settingsRows int
	row := db.QueryRow("select count(*) from owner_settings")
	if err := row.Scan(&settingsRows); err != nil {
		return false, err
	}
	return settingsRows != 1, nil
}
|
|
|
|
// Ctx holds the per-request values that WithCtx stores in the request context
// under the "__ctx" key.
type Ctx struct {
	// DB is the database handle, as populated by WithCtx from its db argument.
	DB *sql.DB
	// Auth is the AuthInfo that checkAuthed produced for the request, as
	// populated by WithCtx.
	Auth AuthInfo
}
|
|
|
|
func WithCtx(db *sql.DB, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
needsOwnerSetup, err := queryNeedsOwnerSetup(db)
|
|
if err != nil {
|
|
log.Printf("Error while querying owner_settings: %v\n", err)
|
|
}
|
|
|
|
if needsOwnerSetup {
|
|
if r.URL.Path != "/setup" {
|
|
w.Header().Add("Location", "/setup")
|
|
w.WriteHeader(http.StatusSeeOther)
|
|
return
|
|
} else {
|
|
routes.SetupGet(w, r)
|
|
return
|
|
}
|
|
}
|
|
|
|
ctx := r.Context()
|
|
auth, err := checkAuthed(db, r)
|
|
if err != nil {
|
|
log.Printf("Error while querying auth info: %v\n", err)
|
|
}
|
|
ctx = context.WithValue(ctx, "__ctx", Ctx{
|
|
DB: db,
|
|
Auth: auth,
|
|
})
|
|
r = r.WithContext(ctx)
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|