diff --git a/config/config.go b/config/config.go index 1ee3b6d..d685d38 100644 --- a/config/config.go +++ b/config/config.go @@ -42,8 +42,9 @@ type Captcha struct { } type EmailConfig struct { - Enable bool `toml:"enable" comment:"注册验证邮件,且允许使用邮箱找回账号"` - Smtp []SmtpUser `toml:"smtp"` + Enable bool `toml:"enable" comment:"注册验证邮件,且允许使用邮箱找回账号"` + Smtp []SmtpUser `toml:"smtp"` + AllowDomain []string `toml:"allow_domain" comment:"允许用于注册的邮箱域名,留空则允许全部"` } type SmtpUser struct { diff --git a/go.mod b/go.mod index 1b9fae0..764bc13 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/redis/go-redis/v9 v9.2.1 github.com/samber/lo v1.38.1 github.com/stretchr/testify v1.8.4 + github.com/wneessen/go-mail v0.4.0 golang.org/x/crypto v0.14.0 ) @@ -38,7 +39,6 @@ require ( github.com/leodido/go-urn v1.2.4 // indirect github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/wneessen/go-mail v0.4.0 // indirect github.com/zclconf/go-cty v1.8.0 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect golang.org/x/mod v0.10.0 // indirect diff --git a/handle/handelerror/error.go b/handle/handelerror/error.go index 85dcdd6..d63ee00 100644 --- a/handle/handelerror/error.go +++ b/handle/handelerror/error.go @@ -11,6 +11,7 @@ import ( "github.com/xmdhs/authlib-skin/service" "github.com/xmdhs/authlib-skin/service/auth" "github.com/xmdhs/authlib-skin/service/captcha" + "github.com/xmdhs/authlib-skin/service/email" ) type HandleError struct { @@ -23,38 +24,32 @@ func NewHandleError(logger *slog.Logger) *HandleError { } } +type errorHandler struct { + ErrorType error + ModelError model.APIStatus + StatusCode int + LogLevel slog.Level +} + +var errorHandlers = []errorHandler{ + {service.ErrExistUser, model.ErrExistUser, 400, slog.LevelDebug}, + {service.ErrExitsName, model.ErrExitsName, 400, slog.LevelDebug}, + {service.ErrRegLimit, model.ErrRegLimit, 400, slog.LevelInfo}, + {captcha.ErrCaptcha, model.ErrCaptcha, 400, slog.LevelDebug}, + {service.ErrPassWord, model.ErrPassWord, 401, slog.LevelInfo}, + {auth.ErrUserDisable, model.ErrUserDisable, 401, slog.LevelDebug}, + {service.ErrNotAdmin, model.ErrNotAdmin, 401, slog.LevelDebug}, + {auth.ErrTokenInvalid, model.ErrAuth, 401, slog.LevelDebug}, + {email.ErrTokenInvalid, model.ErrAuth, 401, slog.LevelDebug}, + {email.ErrSendLimit, model.ErrEmailSend, 403, slog.LevelDebug}, +} + func (h *HandleError) Service(ctx context.Context, w http.ResponseWriter, err error) { - if errors.Is(err, service.ErrExistUser) { - h.Error(ctx, w, err.Error(), model.ErrExistUser, 400, slog.LevelDebug) - return - } - if errors.Is(err, service.ErrExitsName) { - h.Error(ctx, w, err.Error(), model.ErrExitsName, 400, slog.LevelDebug) - return - } - if errors.Is(err, service.ErrRegLimit) { - h.Error(ctx, w, err.Error(), model.ErrRegLimit, 400, slog.LevelDebug) - return - } - if errors.Is(err, captcha.ErrCaptcha) { - h.Error(ctx, w, err.Error(), model.ErrCaptcha, 400, slog.LevelDebug) - return - } - if errors.Is(err, service.ErrPassWord) { - h.Error(ctx, w, err.Error(), model.ErrPassWord, 401, slog.LevelDebug) - return - } - if errors.Is(err, auth.ErrUserDisable) { - h.Error(ctx, w, err.Error(), model.ErrUserDisable, 401, slog.LevelDebug) - return - } - if errors.Is(err, service.ErrNotAdmin) { - h.Error(ctx, w, err.Error(), model.ErrNotAdmin, 401, slog.LevelDebug) - return - } - if errors.Is(err, auth.ErrTokenInvalid) { - h.Error(ctx, w, err.Error(), model.ErrAuth, 401, slog.LevelDebug) - return + for _, errorHandler := range errorHandlers { + if errors.Is(err, errorHandler.ErrorType) { + h.Error(ctx, w, err.Error(), errorHandler.ModelError, errorHandler.StatusCode, errorHandler.LogLevel) + return + } } h.Error(ctx, w, err.Error(), model.ErrService, 500, slog.LevelWarn) diff --git a/handle/user.go b/handle/user.go index 975decb..e03e9cc 100644 --- a/handle/user.go +++ b/handle/user.go @@ -19,13 +19,13 @@ import ( type UserHandel struct { handleError *handelerror.HandleError validate *validator.Validate - userService *service.UserSerice + userService *service.UserService logger *slog.Logger textureService *service.TextureService } func NewUserHandel(handleError *handelerror.HandleError, validate *validator.Validate, - userService *service.UserSerice, logger *slog.Logger, textureService *service.TextureService) *UserHandel { + userService *service.UserService, logger *slog.Logger, textureService *service.TextureService) *UserHandel { return &UserHandel{ handleError: handleError, validate: validate, @@ -212,3 +212,28 @@ func (h *UserHandel) PutTexture() http.HandlerFunc { }) } } + +func (h *UserHandel) SendRegEmail() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + c, err := utils.DeCodeBody[model.SendRegEmail](r.Body, h.validate) + if err != nil { + h.handleError.Error(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug) + return + } + ip, err := utils.GetIP(r) + if err != nil { + h.handleError.Error(ctx, w, err.Error(), model.ErrInput, 400, slog.LevelDebug) + return + } + + err = h.userService.SendRegEmail(ctx, c.Email, c.CaptchaToken, r.Host, ip) + if err != nil { + h.handleError.Service(ctx, w, err) + return + } + encodeJson(w, model.API[any]{ + Code: 0, + }) + } +} diff --git a/model/const.go b/model/const.go index b72c486..19711b6 100644 --- a/model/const.go +++ b/model/const.go @@ -15,4 +15,5 @@ const ( ErrNotAdmin ErrUserDisable ErrCaptcha + ErrEmailSend ) diff --git a/model/model.go b/model/model.go index 3c94a65..58cdce2 100644 --- a/model/model.go +++ b/model/model.go @@ -18,6 +18,7 @@ type UserReg struct { Password string `validate:"required,min=6,max=50"` Name string `validate:"required,min=3,max=16"` CaptchaToken string + EmailJwt string } type TokenClaims struct { @@ -84,3 +85,8 @@ type LoginRep struct { Name string `json:"name"` UUID string `json:"uuid"` } + +type SendRegEmail struct { + Email string `json:"email"` + CaptchaToken string `json:"captchaToken"` +} diff --git a/service/email/email.go b/service/email/email.go index a9630e2..1fb2e27 100644 --- a/service/email/email.go +++ b/service/email/email.go @@ -26,14 +26,14 @@ type EmailConfig struct { Pass string } -type Email struct { +type EmailService struct { emailConfig []EmailConfig pri *rsa.PrivateKey config config.Config cache cache.Cache } -func NewEmail(pri *rsa.PrivateKey, c config.Config, cache cache.Cache) (*Email, error) { +func NewEmail(pri *rsa.PrivateKey, c config.Config, cache cache.Cache) (*EmailService, error) { ec := lo.Map[config.SmtpUser, EmailConfig](c.Email.Smtp, func(item config.SmtpUser, index int) EmailConfig { return EmailConfig{ Host: item.Host, @@ -44,7 +44,7 @@ func NewEmail(pri *rsa.PrivateKey, c config.Config, cache cache.Cache) (*Email, } }) - return &Email{ + return &EmailService{ emailConfig: ec, pri: pri, config: c, @@ -52,12 +52,12 @@ func NewEmail(pri *rsa.PrivateKey, c config.Config, cache cache.Cache) (*Email, }, nil } -func (e Email) getRandEmailUser() EmailConfig { +func (e EmailService) getRandEmailUser() EmailConfig { i := rand.Intn(len(e.emailConfig)) return e.emailConfig[i] } -func (e Email) SendEmail(ctx context.Context, to string, subject, body string) error { +func (e EmailService) SendEmail(ctx context.Context, to string, subject, body string) error { u := e.getRandEmailUser() m := mail.NewMsg() @@ -92,7 +92,7 @@ func (e Email) SendEmail(ctx context.Context, to string, subject, body string) e var emailTemplate = lo.Must(template.New("email").Parse(`
{{ .msg }}
{{ .url }}`)) -func (e Email) SendVerifyUrl(ctx context.Context, email string, interval int, host string) error { +func (e EmailService) SendVerifyUrl(ctx context.Context, email string, interval int, host string) error { sendKey := []byte("SendEmail" + email) sendB, err := e.cache.Get(sendKey) if err != nil { @@ -143,12 +143,11 @@ func (e Email) SendVerifyUrl(ctx context.Context, email string, interval int, ho } var ( - ErrCodeNotValid = errors.New("验证码无效") ErrSendLimit = errors.New("邮件发送限制") ErrTokenInvalid = errors.New("token 无效") ) -func (e Email) VerifyJwt(email, jwtStr string) error { +func (e EmailService) VerifyJwt(email, jwtStr string) error { token, err := jwt.ParseWithClaims(jwtStr, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) { return e.pri.PublicKey, nil }) diff --git a/service/user.go b/service/user.go index f242af3..423721b 100644 --- a/service/user.go +++ b/service/user.go @@ -17,6 +17,7 @@ import ( "github.com/xmdhs/authlib-skin/model" "github.com/xmdhs/authlib-skin/service/auth" "github.com/xmdhs/authlib-skin/service/captcha" + "github.com/xmdhs/authlib-skin/service/email" "github.com/xmdhs/authlib-skin/utils" ) @@ -28,26 +29,28 @@ var ( ErrChangeName = errors.New("离线模式 uuid 不允许修改用户名") ) -type UserSerice struct { +type UserService struct { config config.Config client *ent.Client captchaService *captcha.CaptchaService authService *auth.AuthService cache cache.Cache + emailService *email.EmailService } func NewUserSerice(config config.Config, client *ent.Client, captchaService *captcha.CaptchaService, - authService *auth.AuthService, cache cache.Cache) *UserSerice { - return &UserSerice{ + authService *auth.AuthService, cache cache.Cache, emailService *email.EmailService) *UserService { + return &UserService{ config: config, client: client, captchaService: captchaService, authService: authService, cache: cache, + emailService: emailService, } } -func (w *UserSerice) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip string) (model.LoginRep, error) { +func (w *UserService) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip string) (model.LoginRep, error) { var userUuid string if w.config.OfflineUUID { userUuid = utils.UUIDGen(u.Name) @@ -55,6 +58,13 @@ func (w *UserSerice) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip stri userUuid = strings.ReplaceAll(uuid.New().String(), "-", "") } + if w.config.Email.Enable { + err := w.emailService.VerifyJwt(u.Email, u.EmailJwt) + if err != nil { + return model.LoginRep{}, fmt.Errorf("Reg: %w", err) + } + } + err := w.captchaService.VerifyCaptcha(ctx, u.CaptchaToken, ip) if err != nil { return model.LoginRep{}, fmt.Errorf("Reg: %w", err) @@ -130,7 +140,7 @@ func (w *UserSerice) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip stri }, nil } -func (w *UserSerice) Login(ctx context.Context, l model.Login, ip string) (model.LoginRep, error) { +func (w *UserService) Login(ctx context.Context, l model.Login, ip string) (model.LoginRep, error) { err := w.captchaService.VerifyCaptcha(ctx, l.CaptchaToken, ip) if err != nil { return model.LoginRep{}, fmt.Errorf("Login: %w", err) @@ -158,7 +168,7 @@ func (w *UserSerice) Login(ctx context.Context, l model.Login, ip string) (model }, nil } -func (w *UserSerice) Info(ctx context.Context, t *model.TokenClaims) (model.UserInfo, error) { +func (w *UserService) Info(ctx context.Context, t *model.TokenClaims) (model.UserInfo, error) { u, err := w.client.User.Query().Where(user.ID(t.UID)).First(ctx) if err != nil { return model.UserInfo{}, fmt.Errorf("Info: %w", err) @@ -171,7 +181,7 @@ func (w *UserSerice) Info(ctx context.Context, t *model.TokenClaims) (model.User }, nil } -func (w *UserSerice) ChangePasswd(ctx context.Context, p model.ChangePasswd, t *model.TokenClaims) error { +func (w *UserService) ChangePasswd(ctx context.Context, p model.ChangePasswd, t *model.TokenClaims) error { u, err := w.client.User.Query().Where(user.IDEQ(t.UID)).WithToken().First(ctx) if err != nil { return fmt.Errorf("ChangePasswd: %w", err) @@ -198,7 +208,7 @@ func (w *UserSerice) ChangePasswd(ctx context.Context, p model.ChangePasswd, t * return nil } -func (w *UserSerice) changeName(ctx context.Context, newName string, uid int, uuid string) error { +func (w *UserService) changeName(ctx context.Context, newName string, uid int, uuid string) error { if w.config.OfflineUUID { return fmt.Errorf("changeName: %w", ErrChangeName) } @@ -217,10 +227,38 @@ func (w *UserSerice) changeName(ctx context.Context, newName string, uid int, uu return err } -func (w *UserSerice) ChangeName(ctx context.Context, newName string, t *model.TokenClaims) error { +func (w *UserService) ChangeName(ctx context.Context, newName string, t *model.TokenClaims) error { err := w.changeName(ctx, newName, t.UID, t.Subject) if err != nil { return fmt.Errorf("ChangeName: %w", err) } return nil } + +var ErrNotAllowDomain = errors.New("不在允许域名列表内") + +func (w *UserService) SendRegEmail(ctx context.Context, email, CaptchaToken, host, ip string) error { + if len(w.config.Email.AllowDomain) != 0 { + allow := false + for _, v := range w.config.Email.AllowDomain { + if strings.HasSuffix(email, v) { + allow = true + break + } + } + if !allow { + return fmt.Errorf("SendRegEmail: %w", ErrNotAllowDomain) + } + } + + err := w.captchaService.VerifyCaptcha(ctx, CaptchaToken, ip) + if err != nil { + return fmt.Errorf("SendRegEmail: %w", err) + } + + err = w.emailService.SendVerifyUrl(ctx, email, 60, host) + if err != nil { + return fmt.Errorf("SendRegEmail: %w", ErrNotAllowDomain) + } + return nil +} diff --git a/service/user_test.go b/service/user_test.go index 97f470f..73a5878 100644 --- a/service/user_test.go +++ b/service/user_test.go @@ -17,10 +17,11 @@ import ( "github.com/xmdhs/authlib-skin/model" "github.com/xmdhs/authlib-skin/service/auth" "github.com/xmdhs/authlib-skin/service/captcha" + "github.com/xmdhs/authlib-skin/service/email" ) var ( - userSerice *UserSerice + userSerice *UserService adminSerice *AdminService ) @@ -42,8 +43,9 @@ func initSerice(ctx context.Context) func() { cache := cache.NewFastCache(100000) config := config.Default() authService := auth.NewAuthService(c, cache, &rsa4.PublicKey, rsa4) + email := lo.Must(email.NewEmail(rsa4, config, cache)) - userSerice = NewUserSerice(config, c, captcha.NewCaptchaService(config, &http.Client{}), authService, cache) + userSerice = NewUserSerice(config, c, captcha.NewCaptchaService(config, &http.Client{}), authService, cache, email) adminSerice = NewAdminService(authService, c, config, cache) return func() { @@ -66,7 +68,7 @@ func TestUserSerice_Reg(t *testing.T) { } tests := []struct { name string - w *UserSerice + w *UserService args args wantErr bool }{