diff --git a/service/utils/auth.go b/service/utils/auth.go index e1b3596..1f9e291 100644 --- a/service/utils/auth.go +++ b/service/utils/auth.go @@ -79,3 +79,11 @@ func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, c } return claims, nil } + +func IsAdmin(state int) bool { + return state&1 == 1 +} + +func IsDisable(state int) bool { + return state&2 == 2 +} diff --git a/service/yggdrasil/user.go b/service/yggdrasil/user.go index 9a8a659..6c6f087 100644 --- a/service/yggdrasil/user.go +++ b/service/yggdrasil/user.go @@ -30,9 +30,10 @@ import ( ) var ( - ErrRate = errors.New("频率限制") - ErrPassWord = errors.New("错误的密码或邮箱") - ErrNotUser = errors.New("没有这个用户") + ErrRate = errors.New("频率限制") + ErrPassWord = errors.New("错误的密码或邮箱") + ErrNotUser = errors.New("没有这个用户") + ErrUserDisable = errors.New("用户被禁用") ) func (y *Yggdrasil) validatePass(cxt context.Context, email, pass string) (*ent.User, error) { @@ -60,6 +61,10 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) } + if sutils.IsDisable(u.State) { + return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", ErrUserDisable) + } + clientToken := auth.ClientToken if clientToken == "" { clientToken = strings.ReplaceAll(uuid.New().String(), "-", "")