diff --git a/config/config.go b/config/config.go index d031b65..a651025 100644 --- a/config/config.go +++ b/config/config.go @@ -16,5 +16,6 @@ type Config struct { Type string Ram int } - RaelIP bool + RaelIP bool + MaxIpUser int } diff --git a/db/ent/schema/user.go b/db/ent/schema/user.go index e485feb..ed4608f 100644 --- a/db/ent/schema/user.go +++ b/db/ent/schema/user.go @@ -48,5 +48,6 @@ func (User) Edges() []ent.Edge { func (User) Indexes() []ent.Index { return []ent.Index{ index.Fields("email").Unique(), + index.Fields("reg_ip"), } } diff --git a/go.mod b/go.mod index c57d790..b5e3d3e 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/google/uuid v1.3.1 github.com/google/wire v0.5.0 github.com/julienschmidt/httprouter v1.3.0 + github.com/samber/lo v1.38.1 golang.org/x/crypto v0.7.0 ) @@ -29,7 +30,6 @@ require ( github.com/hashicorp/hcl/v2 v2.13.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect - github.com/samber/lo v1.38.1 // indirect github.com/zclconf/go-cty v1.8.0 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/mod v0.10.0 // indirect diff --git a/handle/user.go b/handle/user.go index 73890a5..ef51419 100644 --- a/handle/user.go +++ b/handle/user.go @@ -36,6 +36,11 @@ func (h *Handel) Reg() httprouter.Handle { handleError(ctx, w, err.Error(), model.ErrExistUser, 400) return } + if errors.Is(err, service.ErrRegLimit) { + h.logger.DebugContext(ctx, err.Error()) + handleError(ctx, w, err.Error(), model.ErrRegLimit, 400) + return + } h.logger.WarnContext(ctx, err.Error()) handleError(ctx, w, err.Error(), model.ErrService, 500) return @@ -55,5 +60,5 @@ func getPrefix(r *http.Request, fromHeader bool) (string, error) { if ipa.Is6() { return lo.Must1(ipa.Prefix(48)).String(), nil } - return ipa.String(), nil + return lo.Must1(ipa.Prefix(24)).String(), nil } diff --git a/model/const.go b/model/const.go index 6db5156..6574d29 100644 --- a/model/const.go +++ b/model/const.go @@ -8,4 +8,5 @@ const ( ErrInput ErrService ErrExistUser + ErrRegLimit ) diff --git a/service/user.go b/service/user.go index 105161b..fa282de 100644 --- a/service/user.go +++ b/service/user.go @@ -20,6 +20,7 @@ import ( var ( ErrExistUser = errors.New("邮箱已存在") ErrExitsName = errors.New("用户名已存在") + ErrRegLimit = errors.New("超过注册 ip 限制") ) func (w *WebService) Reg(ctx context.Context, u model.User, ip string) error { @@ -31,6 +32,16 @@ func (w *WebService) Reg(ctx context.Context, u model.User, ip string) error { } p, s := utils.Argon2ID(u.Password) + if w.config.MaxIpUser != 0 { + c, err := w.client.User.Query().Where(user.RegIPEQ(ip)).Count(ctx) + if err != nil { + return fmt.Errorf("Reg: %w", err) + } + if c >= w.config.MaxIpUser { + return fmt.Errorf("Reg: %w", ErrRegLimit) + } + } + err := utils.WithTx(ctx, w.client, func(tx *ent.Tx) error { count, err := tx.User.Query().Where(user.EmailEQ(u.Email)).ForUpdate().Count(ctx) if err != nil { @@ -51,7 +62,7 @@ func (w *WebService) Reg(ctx context.Context, u model.User, ip string) error { SetPassword(p). SetSalt(s). SetRegTime(time.Now().Unix()). - SetRegIP(""). + SetRegIP(ip). SetState(0).Save(ctx) if err != nil { return err