diff --git a/handle/error.go b/handle/error.go index 828d52f..e1c2a3a 100644 --- a/handle/error.go +++ b/handle/error.go @@ -3,12 +3,14 @@ package handle import ( "context" "encoding/json" + "log/slog" "net/http" "github.com/xmdhs/authlib-skin/model" ) -func handleError(ctx context.Context, w http.ResponseWriter, msg string, code model.APIStatus, httpcode int) { +func (h *Handel) handleError(ctx context.Context, w http.ResponseWriter, msg string, code model.APIStatus, httpcode int, level slog.Level) { + h.logger.Log(ctx, level, msg) w.WriteHeader(httpcode) b, err := json.Marshal(model.API[any]{Code: code, Msg: msg, Data: nil}) if err != nil { diff --git a/handle/handle.go b/handle/handle.go index b28911e..d0edf16 100644 --- a/handle/handle.go +++ b/handle/handle.go @@ -40,14 +40,12 @@ func encodeJson[T any](w io.Writer, m model.API[T]) { func (h *Handel) getTokenbyAuthorization(ctx context.Context, w http.ResponseWriter, r *http.Request) string { auth := r.Header.Get("Authorization") if auth == "" { - h.logger.DebugContext(ctx, "缺少 Authorization") - handleError(ctx, w, "缺少 Authorization", model.ErrAuth, 401) + h.handleError(ctx, w, "缺少 Authorization", model.ErrAuth, 401, slog.LevelDebug) return "" } al := strings.Split(auth, " ") if len(al) != 2 || al[0] != "Bearer" { - h.logger.DebugContext(ctx, "Authorization 格式错误") - handleError(ctx, w, "Authorization 格式错误", model.ErrAuth, 401) + h.handleError(ctx, w, "Authorization 格式错误", model.ErrAuth, 401, slog.LevelDebug) return "" } return al[1] diff --git a/handle/user.go b/handle/user.go index 1bf7782..b832fe4 100644 --- a/handle/user.go +++ b/handle/user.go @@ -2,6 +2,7 @@ package handle import ( "errors" + "log/slog" "net/http" "github.com/julienschmidt/httprouter" @@ -17,37 +18,31 @@ func (h *Handel) Reg() httprouter.Handle { ip, err := utils.GetIP(r, h.config.RaelIP) if err != nil { - h.logger.InfoContext(ctx, err.Error()) - handleError(ctx, w, err.Error(), model.ErrInput, 400) + h.handleError(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug) return } u, err := utils.DeCodeBody[model.User](r.Body, h.validate) if err != nil { - h.logger.InfoContext(ctx, err.Error()) - handleError(ctx, w, err.Error(), model.ErrInput, 400) + h.handleError(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug) return } rip, err := getPrefix(ip) if err != nil { - h.logger.WarnContext(ctx, err.Error()) - handleError(ctx, w, err.Error(), model.ErrUnknown, 500) + h.handleError(ctx, w, err.Error(), model.ErrUnknown, 500, slog.LevelWarn) return } err = h.webService.Reg(ctx, u, rip, ip) if err != nil { if errors.Is(err, service.ErrExistUser) { - h.logger.DebugContext(ctx, err.Error()) - handleError(ctx, w, err.Error(), model.ErrExistUser, 400) + h.handleError(ctx, w, err.Error(), model.ErrExistUser, 400, slog.LevelDebug) return } if errors.Is(err, service.ErrRegLimit) { - h.logger.DebugContext(ctx, err.Error()) - handleError(ctx, w, err.Error(), model.ErrRegLimit, 400) + h.handleError(ctx, w, err.Error(), model.ErrRegLimit, 400, slog.LevelDebug) return } - h.logger.WarnContext(ctx, err.Error()) - handleError(ctx, w, err.Error(), model.ErrService, 500) + h.handleError(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn) return } encodeJson(w, model.API[any]{ @@ -67,12 +62,10 @@ func (h *Handel) UserInfo() httprouter.Handle { u, err := h.webService.Info(ctx, token) if err != nil { if errors.Is(err, utilsService.ErrTokenInvalid) { - h.logger.DebugContext(ctx, "token 无效") - handleError(ctx, w, "token 无效", model.ErrAuth, 401) + h.handleError(ctx, w, "token 无效", model.ErrAuth, 401, slog.LevelDebug) return } - h.logger.InfoContext(ctx, err.Error()) - handleError(ctx, w, err.Error(), model.ErrUnknown, 500) + h.handleError(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn) return } encodeJson(w, model.API[model.UserInfo]{ @@ -81,3 +74,32 @@ func (h *Handel) UserInfo() httprouter.Handle { }) } } + +func (h *Handel) ChangePasswd() httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + ctx := r.Context() + token := h.getTokenbyAuthorization(ctx, w, r) + if token == "" { + return + } + + c, err := utils.DeCodeBody[model.ChangePasswd](r.Body, h.validate) + if err != nil { + h.handleError(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug) + return + } + err = h.webService.ChangePasswd(ctx, c, token) + if err != nil { + if errors.Is(err, service.ErrPassWord) { + h.handleError(ctx, w, err.Error(), model.ErrPassWord, 401, slog.LevelDebug) + return + } + h.handleError(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn) + return + } + encodeJson(w, model.API[any]{ + Code: 0, + }) + + } +} diff --git a/model/const.go b/model/const.go index 4ab0f5f..98d8061 100644 --- a/model/const.go +++ b/model/const.go @@ -10,4 +10,5 @@ const ( ErrExistUser ErrRegLimit ErrAuth + ErrPassWord ) diff --git a/model/model.go b/model/model.go index 5c7edf9..c231eb5 100644 --- a/model/model.go +++ b/model/model.go @@ -35,3 +35,8 @@ type UserInfo struct { UUID string `json:"uuid"` IsAdmin bool `json:"is_admin"` } + +type ChangePasswd struct { + Old string `json:"old"` + New string `json:"new"` +} diff --git a/service/user.go b/service/user.go index cb47e53..eca6cd2 100644 --- a/service/user.go +++ b/service/user.go @@ -21,6 +21,7 @@ var ( ErrExistUser = errors.New("邮箱已存在") ErrExitsName = errors.New("用户名已存在") ErrRegLimit = errors.New("超过注册 ip 限制") + ErrPassWord = errors.New("错误的密码") ) func (w *WebService) Reg(ctx context.Context, u model.User, ipPrefix, ip string) error { @@ -110,3 +111,24 @@ func (w *WebService) Info(ctx context.Context, token string) (model.UserInfo, er IsAdmin: isAdmin, }, nil } + +func (w *WebService) ChangePasswd(ctx context.Context, p model.ChangePasswd, token string) error { + t, err := utilsService.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token}, w.client, w.cache, &w.prikey.PublicKey, false) + if err != nil { + return fmt.Errorf("ChangePasswd: %w", err) + } + u, err := w.client.User.Query().Where(user.IDEQ(t.UID)).First(ctx) + if err != nil { + return fmt.Errorf("ChangePasswd: %w", err) + } + if !utils.Argon2Compare(p.Old, u.Password, u.Salt) { + return fmt.Errorf("ChangePasswd: %w", ErrPassWord) + } + pass, salt := utils.Argon2ID(p.New) + + err = w.client.User.UpdateOne(u).SetPassword(pass).SetSalt(salt).Exec(ctx) + if err != nil { + return fmt.Errorf("ChangePasswd: %w", err) + } + return nil +}