package httpx import ( "context" "database/sql" "log" "net/http" "time" ) type LogWrapper struct { statusCode *int http.ResponseWriter } func (w LogWrapper) WriteHeader(statusCode int) { *w.statusCode = statusCode w.ResponseWriter.WriteHeader(statusCode) } func Log(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { statusCode := 200 wrapper := LogWrapper{statusCode: &statusCode, ResponseWriter: w} start := time.Now() next.ServeHTTP(wrapper, r) log.Println(r.Method, r.URL.Path, statusCode, time.Since(start)) }) } type AuthInfo struct { IsAdmin bool SessionId string } func queryCheckAuth(db *sql.DB, sessionId string) (AuthInfo, error) { sessionCount := 0 err := db.QueryRow("select count(*) from session where session_id = ?", sessionId).Scan(&sessionCount) if err != nil { return AuthInfo{}, err } return AuthInfo{IsAdmin: sessionCount == 1, SessionId: sessionId}, nil } func checkAuthed(db *sql.DB, r *http.Request) (AuthInfo, error) { cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID") if err == http.ErrNoCookie { return AuthInfo{}, nil } if err != nil { return AuthInfo{}, err } sessionId := cookie.Value return queryCheckAuth(db, sessionId) } func queryNeedsOwnerSetup(db *sql.DB) (bool, error) { count := 0 err := db.QueryRow("select count(*) from owner_settings").Scan(&count) if err != nil { return false, err } return count != 1, nil } type Ctx struct { DB *sql.DB Auth AuthInfo NeedsOwnerSetup bool } func GetCtx(r *http.Request) Ctx { ctx := r.Context() return ctx.Value("__ctx").(Ctx) } func WithCtx(db *sql.DB, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { needsOwnerSetup, err := queryNeedsOwnerSetup(db) if err != nil { log.Printf("Error while querying owner_settings: %v\n", err) } ctx := r.Context() auth, err := checkAuthed(db, r) if err != nil { log.Printf("Error while querying auth info: %v\n", err) } ctx = context.WithValue(ctx, "__ctx", Ctx{ DB: db, Auth: auth, NeedsOwnerSetup: needsOwnerSetup, }) r = r.WithContext(ctx) next.ServeHTTP(w, r) }) }