重构错误处理

This commit is contained in:
xmdhs 2023-10-04 00:36:29 +08:00
parent ab5e87ef42
commit 68ce152d06
No known key found for this signature in database
GPG Key ID: E809D6D43DEFCC95
6 changed files with 71 additions and 21 deletions

View File

@ -3,12 +3,14 @@ package handle
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"log/slog"
"net/http" "net/http"
"github.com/xmdhs/authlib-skin/model" "github.com/xmdhs/authlib-skin/model"
) )
func handleError(ctx context.Context, w http.ResponseWriter, msg string, code model.APIStatus, httpcode int) { func (h *Handel) handleError(ctx context.Context, w http.ResponseWriter, msg string, code model.APIStatus, httpcode int, level slog.Level) {
h.logger.Log(ctx, level, msg)
w.WriteHeader(httpcode) w.WriteHeader(httpcode)
b, err := json.Marshal(model.API[any]{Code: code, Msg: msg, Data: nil}) b, err := json.Marshal(model.API[any]{Code: code, Msg: msg, Data: nil})
if err != nil { if err != nil {

View File

@ -40,14 +40,12 @@ func encodeJson[T any](w io.Writer, m model.API[T]) {
func (h *Handel) getTokenbyAuthorization(ctx context.Context, w http.ResponseWriter, r *http.Request) string { func (h *Handel) getTokenbyAuthorization(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
auth := r.Header.Get("Authorization") auth := r.Header.Get("Authorization")
if auth == "" { if auth == "" {
h.logger.DebugContext(ctx, "缺少 Authorization") h.handleError(ctx, w, "缺少 Authorization", model.ErrAuth, 401, slog.LevelDebug)
handleError(ctx, w, "缺少 Authorization", model.ErrAuth, 401)
return "" return ""
} }
al := strings.Split(auth, " ") al := strings.Split(auth, " ")
if len(al) != 2 || al[0] != "Bearer" { if len(al) != 2 || al[0] != "Bearer" {
h.logger.DebugContext(ctx, "Authorization 格式错误") h.handleError(ctx, w, "Authorization 格式错误", model.ErrAuth, 401, slog.LevelDebug)
handleError(ctx, w, "Authorization 格式错误", model.ErrAuth, 401)
return "" return ""
} }
return al[1] return al[1]

View File

@ -2,6 +2,7 @@ package handle
import ( import (
"errors" "errors"
"log/slog"
"net/http" "net/http"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
@ -17,37 +18,31 @@ func (h *Handel) Reg() httprouter.Handle {
ip, err := utils.GetIP(r, h.config.RaelIP) ip, err := utils.GetIP(r, h.config.RaelIP)
if err != nil { if err != nil {
h.logger.InfoContext(ctx, err.Error()) h.handleError(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug)
handleError(ctx, w, err.Error(), model.ErrInput, 400)
return return
} }
u, err := utils.DeCodeBody[model.User](r.Body, h.validate) u, err := utils.DeCodeBody[model.User](r.Body, h.validate)
if err != nil { if err != nil {
h.logger.InfoContext(ctx, err.Error()) h.handleError(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug)
handleError(ctx, w, err.Error(), model.ErrInput, 400)
return return
} }
rip, err := getPrefix(ip) rip, err := getPrefix(ip)
if err != nil { if err != nil {
h.logger.WarnContext(ctx, err.Error()) h.handleError(ctx, w, err.Error(), model.ErrUnknown, 500, slog.LevelWarn)
handleError(ctx, w, err.Error(), model.ErrUnknown, 500)
return return
} }
err = h.webService.Reg(ctx, u, rip, ip) err = h.webService.Reg(ctx, u, rip, ip)
if err != nil { if err != nil {
if errors.Is(err, service.ErrExistUser) { if errors.Is(err, service.ErrExistUser) {
h.logger.DebugContext(ctx, err.Error()) h.handleError(ctx, w, err.Error(), model.ErrExistUser, 400, slog.LevelDebug)
handleError(ctx, w, err.Error(), model.ErrExistUser, 400)
return return
} }
if errors.Is(err, service.ErrRegLimit) { if errors.Is(err, service.ErrRegLimit) {
h.logger.DebugContext(ctx, err.Error()) h.handleError(ctx, w, err.Error(), model.ErrRegLimit, 400, slog.LevelDebug)
handleError(ctx, w, err.Error(), model.ErrRegLimit, 400)
return return
} }
h.logger.WarnContext(ctx, err.Error()) h.handleError(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn)
handleError(ctx, w, err.Error(), model.ErrService, 500)
return return
} }
encodeJson(w, model.API[any]{ encodeJson(w, model.API[any]{
@ -67,12 +62,10 @@ func (h *Handel) UserInfo() httprouter.Handle {
u, err := h.webService.Info(ctx, token) u, err := h.webService.Info(ctx, token)
if err != nil { if err != nil {
if errors.Is(err, utilsService.ErrTokenInvalid) { if errors.Is(err, utilsService.ErrTokenInvalid) {
h.logger.DebugContext(ctx, "token 无效") h.handleError(ctx, w, "token 无效", model.ErrAuth, 401, slog.LevelDebug)
handleError(ctx, w, "token 无效", model.ErrAuth, 401)
return return
} }
h.logger.InfoContext(ctx, err.Error()) h.handleError(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn)
handleError(ctx, w, err.Error(), model.ErrUnknown, 500)
return return
} }
encodeJson(w, model.API[model.UserInfo]{ encodeJson(w, model.API[model.UserInfo]{
@ -81,3 +74,32 @@ func (h *Handel) UserInfo() httprouter.Handle {
}) })
} }
} }
func (h *Handel) ChangePasswd() httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
ctx := r.Context()
token := h.getTokenbyAuthorization(ctx, w, r)
if token == "" {
return
}
c, err := utils.DeCodeBody[model.ChangePasswd](r.Body, h.validate)
if err != nil {
h.handleError(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug)
return
}
err = h.webService.ChangePasswd(ctx, c, token)
if err != nil {
if errors.Is(err, service.ErrPassWord) {
h.handleError(ctx, w, err.Error(), model.ErrPassWord, 401, slog.LevelDebug)
return
}
h.handleError(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn)
return
}
encodeJson(w, model.API[any]{
Code: 0,
})
}
}

View File

@ -10,4 +10,5 @@ const (
ErrExistUser ErrExistUser
ErrRegLimit ErrRegLimit
ErrAuth ErrAuth
ErrPassWord
) )

View File

@ -35,3 +35,8 @@ type UserInfo struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
IsAdmin bool `json:"is_admin"` IsAdmin bool `json:"is_admin"`
} }
type ChangePasswd struct {
Old string `json:"old"`
New string `json:"new"`
}

View File

@ -21,6 +21,7 @@ var (
ErrExistUser = errors.New("邮箱已存在") ErrExistUser = errors.New("邮箱已存在")
ErrExitsName = errors.New("用户名已存在") ErrExitsName = errors.New("用户名已存在")
ErrRegLimit = errors.New("超过注册 ip 限制") ErrRegLimit = errors.New("超过注册 ip 限制")
ErrPassWord = errors.New("错误的密码")
) )
func (w *WebService) Reg(ctx context.Context, u model.User, ipPrefix, ip string) error { func (w *WebService) Reg(ctx context.Context, u model.User, ipPrefix, ip string) error {
@ -110,3 +111,24 @@ func (w *WebService) Info(ctx context.Context, token string) (model.UserInfo, er
IsAdmin: isAdmin, IsAdmin: isAdmin,
}, nil }, nil
} }
func (w *WebService) ChangePasswd(ctx context.Context, p model.ChangePasswd, token string) error {
t, err := utilsService.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token}, w.client, w.cache, &w.prikey.PublicKey, false)
if err != nil {
return fmt.Errorf("ChangePasswd: %w", err)
}
u, err := w.client.User.Query().Where(user.IDEQ(t.UID)).First(ctx)
if err != nil {
return fmt.Errorf("ChangePasswd: %w", err)
}
if !utils.Argon2Compare(p.Old, u.Password, u.Salt) {
return fmt.Errorf("ChangePasswd: %w", ErrPassWord)
}
pass, salt := utils.Argon2ID(p.New)
err = w.client.User.UpdateOne(u).SetPassword(pass).SetSalt(salt).Exec(ctx)
if err != nil {
return fmt.Errorf("ChangePasswd: %w", err)
}
return nil
}