diff --git a/handle/admin.go b/handle/admin.go index ab657e1..4951cb0 100644 --- a/handle/admin.go +++ b/handle/admin.go @@ -1,26 +1,50 @@ package handle import ( + "context" "errors" "log/slog" "net/http" "strconv" "github.com/xmdhs/authlib-skin/model" - utilsService "github.com/xmdhs/authlib-skin/service/utils" + "github.com/xmdhs/authlib-skin/service" + "github.com/xmdhs/authlib-skin/service/utils" ) -func (h *Handel) NeedAdmin(handle http.Handler) http.Handler { +type tokenValue string + +const tokenKey = tokenValue("token") + +func (h *Handel) NeedAuth(handle http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() token := h.getTokenbyAuthorization(ctx, w, r) if token == "" { return } - err := h.webService.IsAdmin(ctx, token) + t, err := h.webService.Auth(ctx, token) if err != nil { - if errors.Is(err, utilsService.ErrTokenInvalid) { - h.handleError(ctx, w, "token 无效", model.ErrAuth, 401, slog.LevelDebug) + if errors.Is(err, utils.ErrTokenInvalid) { + h.handleError(ctx, w, err.Error(), model.ErrAuth, 401, slog.LevelDebug) + return + } + h.handleError(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn) + return + } + r = r.WithContext(context.WithValue(ctx, tokenKey, t)) + handle.ServeHTTP(w, r) + }) +} + +func (h *Handel) NeedAdmin(handle http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + t := ctx.Value(tokenKey).(*model.TokenClaims) + err := h.webService.IsAdmin(ctx, t) + if err != nil { + if errors.Is(err, service.ErrNotAdmin) { + h.handleError(ctx, w, err.Error(), model.ErrNotAdmin, 401, slog.LevelDebug) return } h.handleError(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn) diff --git a/handle/user.go b/handle/user.go index a51f61a..f7c6a51 100644 --- a/handle/user.go +++ b/handle/user.go @@ -57,12 +57,8 @@ func (h *Handel) Reg() http.HandlerFunc { func (h *Handel) UserInfo() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - token := h.getTokenbyAuthorization(ctx, w, r) - if token == "" { - return - } - - u, err := h.webService.Info(ctx, token) + t := ctx.Value(tokenKey).(*model.TokenClaims) + u, err := h.webService.Info(ctx, t) if err != nil { if errors.Is(err, utilsService.ErrTokenInvalid) { h.handleError(ctx, w, "token 无效", model.ErrAuth, 401, slog.LevelDebug) @@ -81,17 +77,14 @@ func (h *Handel) UserInfo() http.HandlerFunc { func (h *Handel) ChangePasswd() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - token := h.getTokenbyAuthorization(ctx, w, r) - if token == "" { - return - } + t := ctx.Value(tokenKey).(*model.TokenClaims) c, err := utils.DeCodeBody[model.ChangePasswd](r.Body, h.validate) if err != nil { h.handleError(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug) return } - err = h.webService.ChangePasswd(ctx, c, token) + err = h.webService.ChangePasswd(ctx, c, t) if err != nil { if errors.Is(err, service.ErrPassWord) { h.handleError(ctx, w, err.Error(), model.ErrPassWord, 401, slog.LevelDebug) @@ -110,16 +103,13 @@ func (h *Handel) ChangePasswd() http.HandlerFunc { func (h *Handel) ChangeName() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - token := h.getTokenbyAuthorization(ctx, w, r) - if token == "" { - return - } + t := ctx.Value(tokenKey).(*model.TokenClaims) c, err := utils.DeCodeBody[model.ChangeName](r.Body, h.validate) if err != nil { h.handleError(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug) return } - err = h.webService.ChangeName(ctx, c.Name, token) + err = h.webService.ChangeName(ctx, c.Name, t) if err != nil { if errors.Is(err, service.ErrExitsName) { h.handleError(ctx, w, err.Error(), model.ErrExitsName, 400, slog.LevelDebug) diff --git a/model/const.go b/model/const.go index 893d3e3..bab6766 100644 --- a/model/const.go +++ b/model/const.go @@ -12,4 +12,5 @@ const ( ErrAuth ErrPassWord ErrExitsName + ErrNotAdmin ) diff --git a/server/route/route.go b/server/route/route.go index 6ed5304..940c4ae 100644 --- a/server/route/route.go +++ b/server/route/route.go @@ -60,14 +60,18 @@ func newSkinApi(handel *handle.Handel) http.Handler { r.Put("/user/reg", handel.Reg()) r.Get("/config", handel.GetConfig()) - r.Get("/user", handel.UserInfo()) - r.Post("/user/password", handel.ChangePasswd()) - r.Post("/user/name", handel.ChangeName()) r.Group(func(r chi.Router) { - r.Use(handel.NeedAdmin) - r.Get("/admin/users", handel.ListUser()) + r.Use(handel.NeedAuth) + r.Get("/user", handel.UserInfo()) + r.Post("/user/password", handel.ChangePasswd()) + r.Post("/user/name", handel.ChangeName()) + r.Group(func(r chi.Router) { + r.Use(handel.NeedAdmin) + r.Get("/admin/users", handel.ListUser()) + }) }) + return r } diff --git a/service/admin.go b/service/admin.go index 9f84e46..20f6922 100644 --- a/service/admin.go +++ b/service/admin.go @@ -13,11 +13,15 @@ import ( var ErrNotAdmin = errors.New("无权限") -func (w *WebService) IsAdmin(ctx context.Context, token string) error { +func (w *WebService) Auth(ctx context.Context, token string) (*model.TokenClaims, error) { t, err := utilsService.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token}, w.client, w.cache, &w.prikey.PublicKey, false) if err != nil { - return fmt.Errorf("IsAdmin: %w", err) + return nil, fmt.Errorf("Auth: %w", err) } + return t, nil +} + +func (w *WebService) IsAdmin(ctx context.Context, t *model.TokenClaims) error { u, err := w.client.User.Query().Where(user.ID(t.UID)).First(ctx) if err != nil { return fmt.Errorf("IsAdmin: %w", err) diff --git a/service/user.go b/service/user.go index 7612743..dd7b480 100644 --- a/service/user.go +++ b/service/user.go @@ -13,7 +13,6 @@ import ( "github.com/xmdhs/authlib-skin/db/ent/user" "github.com/xmdhs/authlib-skin/db/ent/userprofile" "github.com/xmdhs/authlib-skin/model" - "github.com/xmdhs/authlib-skin/model/yggdrasil" utilsService "github.com/xmdhs/authlib-skin/service/utils" "github.com/xmdhs/authlib-skin/utils" ) @@ -94,11 +93,7 @@ func (w *WebService) Reg(ctx context.Context, u model.User, ipPrefix, ip string) return nil } -func (w *WebService) Info(ctx context.Context, token string) (model.UserInfo, error) { - t, err := utilsService.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token}, w.client, w.cache, &w.prikey.PublicKey, false) - if err != nil { - return model.UserInfo{}, fmt.Errorf("Info: %w", err) - } +func (w *WebService) Info(ctx context.Context, t *model.TokenClaims) (model.UserInfo, error) { u, err := w.client.User.Query().Where(user.ID(t.UID)).First(ctx) if err != nil { return model.UserInfo{}, fmt.Errorf("Info: %w", err) @@ -111,11 +106,7 @@ func (w *WebService) Info(ctx context.Context, token string) (model.UserInfo, er }, nil } -func (w *WebService) ChangePasswd(ctx context.Context, p model.ChangePasswd, token string) error { - t, err := utilsService.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token}, w.client, w.cache, &w.prikey.PublicKey, false) - if err != nil { - return fmt.Errorf("ChangePasswd: %w", err) - } +func (w *WebService) ChangePasswd(ctx context.Context, p model.ChangePasswd, t *model.TokenClaims) error { u, err := w.client.User.Query().Where(user.IDEQ(t.UID)).WithToken().First(ctx) if err != nil { return fmt.Errorf("ChangePasswd: %w", err) @@ -160,12 +151,8 @@ func (w *WebService) changeName(ctx context.Context, newName string, uid int, uu return err } -func (w *WebService) ChangeName(ctx context.Context, newName string, token string) error { - t, err := utilsService.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token}, w.client, w.cache, &w.prikey.PublicKey, false) - if err != nil { - return fmt.Errorf("ChangeName: %w", err) - } - err = w.changeName(ctx, newName, t.UID, t.Subject) +func (w *WebService) ChangeName(ctx context.Context, newName string, t *model.TokenClaims) error { + err := w.changeName(ctx, newName, t.UID, t.Subject) if err != nil { return fmt.Errorf("ChangeName: %w", err) }