diff --git a/internal/httpx/middleware.go b/internal/httpx/middleware.go index 7e6ae34..1a3d980 100644 --- a/internal/httpx/middleware.go +++ b/internal/httpx/middleware.go @@ -1,90 +1,44 @@ package httpx import ( - "context" - "database/sql" - "git.soup.land/soup/shelves/internal/auth" "log" "net/http" "time" ) -type LogWrapper struct { - statusCode *int +type ResponseWriterTracker struct { + StatusCode int + Wrote bool http.ResponseWriter } -func (w LogWrapper) WriteHeader(statusCode int) { - *w.statusCode = statusCode +func (w ResponseWriterTracker) WasWritten() bool { + return w.StatusCode != 0 || w.Wrote +} + +func (w *ResponseWriterTracker) WriteHeader(statusCode int) { + w.StatusCode = statusCode w.ResponseWriter.WriteHeader(statusCode) } +func (w *ResponseWriterTracker) Write(b []byte) (int, error) { + w.Wrote = true + + return w.ResponseWriter.Write(b) +} + +func NewResponseWriterTracker(w http.ResponseWriter) ResponseWriterTracker { + return ResponseWriterTracker{ResponseWriter: w} +} + func Log(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - statusCode := 200 - wrapper := LogWrapper{statusCode: &statusCode, ResponseWriter: w} + wt := NewResponseWriterTracker(w) start := time.Now() - next.ServeHTTP(wrapper, r) + next.ServeHTTP(&wt, r) - log.Println(r.Method, r.URL.Path, statusCode, time.Since(start)) - }) -} - -func queryNeedsOwnerSetup(db *sql.DB) (bool, error) { - count := 0 - err := db.QueryRow("select count(*) from owner_settings").Scan(&count) - if err != nil { - return false, err - } - - return count != 1, nil -} - -func sessionInfo(db *sql.DB, r *http.Request) (auth.SessionInfo, error) { - cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID") - if err == http.ErrNoCookie { - return auth.SessionInfo{}, nil - } - if err != nil { - return auth.SessionInfo{}, err - } - - return auth.SessionCheck(db, cookie.Value) -} - -type Ctx struct { - DB *sql.DB - SessionInfo auth.SessionInfo - NeedsOwnerSetup bool -} - -func GetCtx(r *http.Request) Ctx { - ctx := r.Context() - - return ctx.Value("__ctx").(Ctx) -} - -func WithCtx(db *sql.DB, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - needsOwnerSetup, err := queryNeedsOwnerSetup(db) - if err != nil { - log.Printf("Error while querying owner_settings: %v\n", err) - } - - ctx := r.Context() - session, err := sessionInfo(db, r) - if err != nil { - log.Printf("Error while querying session info: %v\n", err) - } - ctx = context.WithValue(ctx, "__ctx", Ctx{ - DB: db, - SessionInfo: session, - NeedsOwnerSetup: needsOwnerSetup, - }) - r = r.WithContext(ctx) - - next.ServeHTTP(w, r) + log.Println(r.Method, r.URL.Path, wt.StatusCode, time.Since(start)) }) } diff --git a/internal/routes/home.go b/internal/routes/home.go index 15ac3de..56c57a4 100644 --- a/internal/routes/home.go +++ b/internal/routes/home.go @@ -3,7 +3,6 @@ package routes import ( "net/http" - "git.soup.land/soup/shelves/internal/httpx" "git.soup.land/soup/shelves/internal/templates" "git.soup.land/soup/shelves/internal/templates/components" ) @@ -13,8 +12,8 @@ type homeContent struct { var homeTmpl = templates.MustParseEmbed("views/home.tmpl.html") -func HomeGet(w http.ResponseWriter, req *http.Request) { - ctx := httpx.GetCtx(req) +func homeGet(w http.ResponseWriter, req *http.Request) { + ctx := getCtx(req) h := components.Page{ Title: "Home", diff --git a/internal/routes/login.go b/internal/routes/login.go index f762736..a2e8087 100644 --- a/internal/routes/login.go +++ b/internal/routes/login.go @@ -22,7 +22,7 @@ type loginFormErrors struct { password error } -func loginRenderView(f loginForm, e loginFormErrors, ctx httpx.Ctx) template.HTML { +func loginRenderView(f loginForm, e loginFormErrors, ctx Ctx) template.HTML { formHtml := components.Form{ Action: "/login", Fields: []components.Field{ @@ -64,8 +64,8 @@ func getRedirectTo(r *http.Request) string { return redirectTo } -func LoginGet(w http.ResponseWriter, r *http.Request) { - ctx := httpx.GetCtx(r) +func loginGet(w http.ResponseWriter, r *http.Request) { + ctx := getCtx(r) redirectTo := getRedirectTo(r) if ctx.SessionInfo.IsAdmin { @@ -81,8 +81,8 @@ func loginParseForm(f *loginForm, e *loginFormErrors, vs url.Values, v *forms.Va f.redirectTo = vs.Get("redirectTo") } -func LoginPost(w http.ResponseWriter, r *http.Request) { - ctx := httpx.GetCtx(r) +func loginPost(w http.ResponseWriter, r *http.Request) { + ctx := getCtx(r) form := loginForm{} errs := loginFormErrors{} @@ -123,8 +123,8 @@ func LoginPost(w http.ResponseWriter, r *http.Request) { httpx.SeeOther(w, form.redirectTo) } -func LoginDelete(w http.ResponseWriter, r *http.Request) { - ctx := httpx.GetCtx(r) +func loginDelete(w http.ResponseWriter, r *http.Request) { + ctx := getCtx(r) httpx.HxRefresh(w) if !ctx.SessionInfo.IsAdmin { diff --git a/internal/routes/routes.go b/internal/routes/routes.go index 41cc9fd..aca6e27 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -1,12 +1,17 @@ package routes import ( - "git.soup.land/soup/shelves" - "git.soup.land/soup/shelves/internal/httpx" + "context" + "database/sql" "html/template" + "log" "net/http" "strings" "time" + + "git.soup.land/soup/shelves" + "git.soup.land/soup/shelves/internal/auth" + "git.soup.land/soup/shelves/internal/httpx" ) func html(w http.ResponseWriter, s template.HTML) { @@ -16,7 +21,7 @@ func html(w http.ResponseWriter, s template.HTML) { func redirectSetup(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := httpx.GetCtx(r) + ctx := getCtx(r) if ctx.NeedsOwnerSetup && !strings.HasPrefix(r.URL.Path, "/static") { if r.URL.Path != "/setup" { httpx.SeeOther(w, "/setup") @@ -28,7 +33,7 @@ func redirectSetup(next http.Handler) http.Handler { }) } -func ServeStatic() http.Handler { +func serveStatic() http.Handler { inner := http.StripPrefix("/static/", http.FileServerFS(shelves.Frontend)) startup := time.Now().UTC() @@ -47,15 +52,93 @@ func ServeStatic() http.Handler { }) } -func Routes() http.Handler { - mux := http.NewServeMux() - mux.Handle("GET /static/", ServeStatic()) - mux.HandleFunc("GET /{$}", HomeGet) - mux.HandleFunc("GET /settings", SettingsGet) - mux.HandleFunc("POST /settings", SettingsPost) - mux.HandleFunc("GET /login", LoginGet) - mux.HandleFunc("POST /login", LoginPost) - mux.HandleFunc("DELETE /login", LoginDelete) +func seqHandler(a, b http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wt := httpx.NewResponseWriterTracker(w) - return redirectSetup(mux) + a.ServeHTTP(&wt, r) + if !wt.WasWritten() { + b.ServeHTTP(&wt, r) + } + }) +} + +func Routes(db *sql.DB) http.Handler { + public := http.NewServeMux() + public.Handle("GET /static/", serveStatic()) + public.HandleFunc("GET /{$}", homeGet) + public.HandleFunc("GET /login", loginGet) + public.HandleFunc("POST /login", loginPost) + public.HandleFunc("DELETE /login", loginDelete) + public.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) + + ownerOnly := http.NewServeMux() + ownerOnly.HandleFunc("GET /settings", settingsGet) + ownerOnly.HandleFunc("POST /settings", settingsPost) + ownerOnly.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) + + h := seqHandler(public, ownerOnly) + h = seqHandler(h, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("404 not found")) + })) + + return withCtx(db, redirectSetup(h)) +} + +func queryNeedsOwnerSetup(db *sql.DB) (bool, error) { + count := 0 + err := db.QueryRow("select count(*) from owner_settings").Scan(&count) + if err != nil { + return false, err + } + + return count != 1, nil +} + +func sessionInfo(db *sql.DB, r *http.Request) (auth.SessionInfo, error) { + cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID") + if err == http.ErrNoCookie { + return auth.SessionInfo{}, nil + } + if err != nil { + return auth.SessionInfo{}, err + } + + return auth.SessionCheck(db, cookie.Value) +} + +type Ctx struct { + DB *sql.DB + SessionInfo auth.SessionInfo + NeedsOwnerSetup bool +} + +func getCtx(r *http.Request) Ctx { + cx := r.Context() + + return cx.Value("__ctx").(Ctx) +} + +func withCtx(db *sql.DB, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + needsOwnerSetup, err := queryNeedsOwnerSetup(db) + if err != nil { + log.Printf("Error while querying owner_settings: %v\n", err) + } + + ctx := r.Context() + session, err := sessionInfo(db, r) + if err != nil { + log.Printf("Error while querying session info: %v\n", err) + } + ctx = context.WithValue(ctx, "__ctx", Ctx{ + DB: db, + SessionInfo: session, + NeedsOwnerSetup: needsOwnerSetup, + }) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) } diff --git a/internal/routes/settings.go b/internal/routes/settings.go index f6e3275..37aa501 100644 --- a/internal/routes/settings.go +++ b/internal/routes/settings.go @@ -14,7 +14,7 @@ import ( "git.soup.land/soup/shelves/internal/templates/components" ) -func settingsRenderView(f settingsForm, e settingsFormErrors, ctx httpx.Ctx) template.HTML { +func settingsRenderView(f settingsForm, e settingsFormErrors, ctx Ctx) template.HTML { body := components.Form{ Action: "/settings", Fields: []components.Field{ @@ -54,8 +54,8 @@ func queryGetOwnerSettings(db *sql.DB) (settingsForm, error) { return form, err } -func SettingsGet(w http.ResponseWriter, r *http.Request) { - ctx := httpx.GetCtx(r) +func settingsGet(w http.ResponseWriter, r *http.Request) { + ctx := getCtx(r) if !ctx.SessionInfo.IsAdmin && !ctx.NeedsOwnerSetup { httpx.LoginRedirect(w, *r.URL) return @@ -113,8 +113,8 @@ func updateOwnerSettings(db *sql.DB, f settingsForm) error { return queryUpdateOwnerSettings(db, f.displayName, salt, hash) } -func SettingsPost(w http.ResponseWriter, r *http.Request) { - ctx := httpx.GetCtx(r) +func settingsPost(w http.ResponseWriter, r *http.Request) { + ctx := getCtx(r) if !ctx.SessionInfo.IsAdmin && !ctx.NeedsOwnerSetup { httpx.Unauthorized(w)