package routes import ( "context" "database/sql" "html/template" "log" "net/http" "strings" "time" "git.soup.land/soup/shelves" "git.soup.land/soup/shelves/internal/auth" "git.soup.land/soup/shelves/internal/httpx" ) func html(w http.ResponseWriter, s template.HTML) { w.Header().Add("Content-Type", "text/html") w.Write([]byte(s)) } func redirectSetup(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := getCtx(r) if ctx.NeedsOwnerSetup && !strings.HasPrefix(r.URL.Path, "/static") { if r.URL.Path != "/setup" { httpx.SeeOther(w, "/setup") return } } next.ServeHTTP(w, r) }) } func serveStatic() http.Handler { inner := http.StripPrefix("/static/", http.FileServerFS(shelves.Frontend)) startup := time.Now().UTC() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ims := r.Header.Get("If-Modified-Since") imsTime, _ := time.Parse(http.TimeFormat, ims) imsTime = imsTime.Add(time.Second * 10) if imsTime.Before(startup) { w.Header().Add("Last-Modified", startup.Format(http.TimeFormat)) w.Header().Add("Cache-Control", "public, max-age=0, stale-while-revalidate=9999999") inner.ServeHTTP(w, r) } else { w.WriteHeader(http.StatusNotModified) } }) } func seqHandler(a, b http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { wt := httpx.NewResponseWriterTracker(w) a.ServeHTTP(&wt, r) if !wt.WasWritten() { b.ServeHTTP(&wt, r) } }) } func Routes(db *sql.DB) http.Handler { public := http.NewServeMux() public.Handle("GET /static/", serveStatic()) public.HandleFunc("GET /{$}", homeGet) public.HandleFunc("GET /login", loginGet) public.HandleFunc("POST /login", loginPost) public.HandleFunc("DELETE /login", loginDelete) public.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) ownerOnly := http.NewServeMux() ownerOnly.HandleFunc("GET /settings", settingsGet) ownerOnly.HandleFunc("POST /settings", settingsPost) ownerOnly.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) h := seqHandler(public, ownerOnly) h = seqHandler(h, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) w.Write([]byte("404 not found")) })) return withCtx(db, redirectSetup(h)) } func queryNeedsOwnerSetup(db *sql.DB) (bool, error) { count := 0 err := db.QueryRow("select count(*) from owner_settings").Scan(&count) if err != nil { return false, err } return count != 1, nil } func sessionInfo(db *sql.DB, r *http.Request) (auth.SessionInfo, error) { cookie, err := r.Cookie("SHELVES_OWNER_SESSION_ID") if err == http.ErrNoCookie { return auth.SessionInfo{}, nil } if err != nil { return auth.SessionInfo{}, err } return auth.SessionCheck(db, cookie.Value) } type Ctx struct { DB *sql.DB SessionInfo auth.SessionInfo NeedsOwnerSetup bool } func getCtx(r *http.Request) Ctx { cx := r.Context() return cx.Value("__ctx").(Ctx) } func withCtx(db *sql.DB, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { needsOwnerSetup, err := queryNeedsOwnerSetup(db) if err != nil { log.Printf("Error while querying owner_settings: %v\n", err) } ctx := r.Context() session, err := sessionInfo(db, r) if err != nil { log.Printf("Error while querying session info: %v\n", err) } ctx = context.WithValue(ctx, "__ctx", Ctx{ DB: db, SessionInfo: session, NeedsOwnerSetup: needsOwnerSetup, }) r = r.WithContext(ctx) next.ServeHTTP(w, r) }) }