diff --git a/handle/yggdrasil/user.go b/handle/yggdrasil/user.go index 6a4535e..d0e723b 100644 --- a/handle/yggdrasil/user.go +++ b/handle/yggdrasil/user.go @@ -9,16 +9,13 @@ import ( "github.com/xmdhs/authlib-skin/model/yggdrasil" sutils "github.com/xmdhs/authlib-skin/service/utils" yggdrasilS "github.com/xmdhs/authlib-skin/service/yggdrasil" - "github.com/xmdhs/authlib-skin/utils" ) func (y *Yggdrasil) Authenticate() httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { cxt := r.Context() - a, err := utils.DeCodeBody[yggdrasil.Authenticate](r.Body, y.validate) - if err != nil { - y.logger.DebugContext(cxt, err.Error()) - handleYgError(cxt, w, yggdrasil.Error{ErrorMessage: err.Error()}, 400) + a, has := getAnyModel[yggdrasil.Authenticate](cxt, w, r.Body, y.validate, y.logger) + if !has { return } t, err := y.yggdrasilService.Authenticate(cxt, a) @@ -40,13 +37,11 @@ func (y *Yggdrasil) Authenticate() httprouter.Handle { func (y *Yggdrasil) Validate() httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { cxt := r.Context() - a, err := utils.DeCodeBody[yggdrasil.ValidateToken](r.Body, y.validate) - if err != nil { - y.logger.DebugContext(cxt, err.Error()) - handleYgError(cxt, w, yggdrasil.Error{ErrorMessage: err.Error()}, 400) + a, has := getAnyModel[yggdrasil.ValidateToken](cxt, w, r.Body, y.validate, y.logger) + if !has { return } - err = y.yggdrasilService.ValidateToken(cxt, a) + err := y.yggdrasilService.ValidateToken(cxt, a) if err != nil { if errors.Is(err, sutils.ErrTokenInvalid) { y.logger.DebugContext(cxt, err.Error()) @@ -60,3 +55,44 @@ func (y *Yggdrasil) Validate() httprouter.Handle { w.WriteHeader(204) } } + +func (y *Yggdrasil) Signout() httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + cxt := r.Context() + a, has := getAnyModel[yggdrasil.Pass](cxt, w, r.Body, y.validate, y.logger) + if !has { + return + } + err := y.yggdrasilService.SignOut(cxt, a) + if err != nil { + if errors.Is(err, yggdrasilS.ErrPassWord) || errors.Is(err, yggdrasilS.ErrRate) { + y.logger.DebugContext(cxt, err.Error()) + handleYgError(cxt, w, yggdrasil.Error{ErrorMessage: "Invalid credentials. Invalid username or password.", Error: "ForbiddenOperationException"}, 403) + return + } + y.logger.WarnContext(cxt, err.Error()) + handleYgError(cxt, w, yggdrasil.Error{ErrorMessage: err.Error()}, 500) + return + } + w.WriteHeader(204) + } +} + +func (y *Yggdrasil) Invalidate() httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + w.WriteHeader(204) + cxt := r.Context() + a, has := getAnyModel[yggdrasil.ValidateToken](cxt, w, r.Body, y.validate, y.logger) + if !has { + return + } + err := y.yggdrasilService.Invalidate(cxt, a.AccessToken) + if err != nil { + if errors.Is(err, sutils.ErrTokenInvalid) { + y.logger.DebugContext(cxt, err.Error()) + return + } + y.logger.WarnContext(cxt, err.Error()) + } + } +} diff --git a/handle/yggdrasil/yggdrasil.go b/handle/yggdrasil/yggdrasil.go index 8449c36..233ddd3 100644 --- a/handle/yggdrasil/yggdrasil.go +++ b/handle/yggdrasil/yggdrasil.go @@ -1,10 +1,15 @@ package yggdrasil import ( + "context" + "io" "log/slog" + "net/http" "github.com/go-playground/validator/v10" + "github.com/xmdhs/authlib-skin/model/yggdrasil" yggdrasilS "github.com/xmdhs/authlib-skin/service/yggdrasil" + "github.com/xmdhs/authlib-skin/utils" ) type Yggdrasil struct { @@ -20,3 +25,13 @@ func NewYggdrasil(logger *slog.Logger, validate *validator.Validate, yggdrasilSe yggdrasilService: yggdrasilService, } } + +func getAnyModel[K any](ctx context.Context, w http.ResponseWriter, r io.Reader, validate *validator.Validate, slog *slog.Logger) (K, bool) { + a, err := utils.DeCodeBody[K](r, validate) + if err != nil { + slog.DebugContext(ctx, err.Error()) + handleYgError(ctx, w, yggdrasil.Error{ErrorMessage: err.Error()}, 400) + return a, false + } + return a, true +} diff --git a/model/yggdrasil/model.go b/model/yggdrasil/model.go index 899dbeb..93e7ee0 100644 --- a/model/yggdrasil/model.go +++ b/model/yggdrasil/model.go @@ -1,14 +1,18 @@ package yggdrasil +type Pass struct { + Username string `json:"username" validate:"required"` + Password string `json:"password" validate:"required"` +} + type Authenticate struct { Agent struct { Name string `json:"name" validate:"required,eq=Minecraft"` Version int `json:"version" validate:"required,eq=1"` } `json:"agent"` ClientToken string `json:"clientToken"` - Password string `json:"password" validate:"required"` RequestUser bool `json:"requestUser"` - Username string `json:"username" validate:"required"` + Pass } type Error struct { diff --git a/server/route/route.go b/server/route/route.go index 794a289..360a87d 100644 --- a/server/route/route.go +++ b/server/route/route.go @@ -24,6 +24,9 @@ func NewRoute(yggService *yggdrasil.Yggdrasil, handel *handle.Handel) (*httprout func newYggdrasil(r *httprouter.Router, handelY yggdrasil.Yggdrasil) error { r.POST("/api/authserver/authenticate", warpHJSON(handelY.Authenticate())) r.POST("/api/authserver/validate", warpHJSON(handelY.Validate())) + r.POST("/api/authserver/signout", warpHJSON(handelY.Signout())) + r.POST("/api/authserver/invalidate", handelY.Invalidate()) + // TODO /authserver/refresh return nil } diff --git a/service/utils/auth.go b/service/utils/auth.go index 4337eb0..32efeba 100644 --- a/service/utils/auth.go +++ b/service/utils/auth.go @@ -18,41 +18,41 @@ var ( ErrTokenInvalid = errors.New("token 无效") ) -func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, jwtKey string) error { +func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, jwtKey string) (*model.TokenClaims, error) { token, err := jwt.ParseWithClaims(t.AccessToken, &model.TokenClaims{}, func(t *jwt.Token) (interface{}, error) { return []byte(jwtKey), nil }) if err != nil { - return fmt.Errorf("Auth: %w", err) + return nil, fmt.Errorf("Auth: %w", err) } claims, ok := token.Claims.(*model.TokenClaims) if !ok || !token.Valid { - return fmt.Errorf("Auth: %w", ErrTokenInvalid) + return nil, fmt.Errorf("Auth: %w", ErrTokenInvalid) } if t.ClientToken != "" && t.ClientToken != claims.CID { - return fmt.Errorf("Auth: %w", ErrTokenInvalid) + return nil, fmt.Errorf("Auth: %w", ErrTokenInvalid) } it, err := claims.GetIssuedAt() if err != nil { - return fmt.Errorf("Auth: %w", err) + return nil, fmt.Errorf("Auth: %w", err) } et, err := claims.GetExpirationTime() if err != nil { - return fmt.Errorf("Auth: %w", err) + return nil, fmt.Errorf("Auth: %w", err) } invalidTime := it.Add(et.Time.Sub(it.Time)) if time.Now().After(invalidTime) { - return fmt.Errorf("Auth: %w", ErrTokenInvalid) + return nil, fmt.Errorf("Auth: %w", ErrTokenInvalid) } ut, err := client.UserToken.Query().Where(usertoken.UUIDEQ(claims.Subject)).First(ctx) if err != nil { - return fmt.Errorf("Auth: %w", err) + return nil, fmt.Errorf("Auth: %w", err) } if strconv.FormatUint(ut.TokenID, 10) != claims.Tid { - return fmt.Errorf("Auth: %w", ErrTokenInvalid) + return nil, fmt.Errorf("Auth: %w", ErrTokenInvalid) } - return nil + return claims, nil } diff --git a/service/yggdrasil/authenticate.go b/service/yggdrasil/user.go similarity index 57% rename from service/yggdrasil/authenticate.go rename to service/yggdrasil/user.go index 91e35ba..c8f7c63 100644 --- a/service/yggdrasil/authenticate.go +++ b/service/yggdrasil/user.go @@ -2,7 +2,6 @@ package yggdrasil import ( "context" - "encoding/binary" "errors" "fmt" "strconv" @@ -11,11 +10,12 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" - "github.com/xmdhs/authlib-skin/db/cache" "github.com/xmdhs/authlib-skin/db/ent" "github.com/xmdhs/authlib-skin/db/ent/user" + "github.com/xmdhs/authlib-skin/db/ent/usertoken" "github.com/xmdhs/authlib-skin/model" "github.com/xmdhs/authlib-skin/model/yggdrasil" + sutils "github.com/xmdhs/authlib-skin/service/utils" "github.com/xmdhs/authlib-skin/utils" ) @@ -24,23 +24,31 @@ var ( ErrPassWord = errors.New("错误的密码或邮箱") ) +func (y *Yggdrasil) validatePass(cxt context.Context, email, pass string) (*ent.User, error) { + err := rate("validatePass"+email, y.cache, 10*time.Second, 3) + if err != nil { + return nil, fmt.Errorf("validatePass: %w", err) + } + u, err := y.client.User.Query().Where(user.EmailEQ(email)).WithProfile().First(cxt) + if err != nil { + var nf *ent.NotFoundError + if errors.As(err, &nf) { + return nil, fmt.Errorf("validatePass: %w", ErrPassWord) + } + return nil, fmt.Errorf("validatePass: %w", err) + } + if !utils.Argon2Compare(pass, u.Password, u.Salt) { + return nil, fmt.Errorf("validatePass: %w", ErrPassWord) + } + return u, nil +} + func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticate) (yggdrasil.Token, error) { - err := rate("Authenticate"+auth.Username, y.cache, 10*time.Second, 3) + u, err := y.validatePass(cxt, auth.Username, auth.Password) if err != nil { return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) } - u, err := y.client.User.Query().Where(user.EmailEQ(auth.Username)).WithProfile().First(cxt) - if err != nil { - var nf *ent.NotFoundError - if errors.As(err, &nf) { - return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", ErrPassWord) - } - return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) - } - if !utils.Argon2Compare(auth.Password, u.Password, u.Salt) { - return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", ErrPassWord) - } clientToken := auth.ClientToken if clientToken == "" { clientToken = strings.ReplaceAll(uuid.New().String(), "-", "") @@ -104,36 +112,42 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat }, nil } -func rate(k string, c cache.Cache, d time.Duration, count uint) error { - key := []byte(k) - v, err := c.Get([]byte(key)) +func (y *Yggdrasil) ValidateToken(ctx context.Context, t yggdrasil.ValidateToken) error { + _, err := sutils.Auth(ctx, t, y.client, y.config.JwtKey) if err != nil { - return fmt.Errorf("rate: %w", err) - } - if v == nil { - err := putUint(1, c, key, d) - if err != nil { - return fmt.Errorf("rate: %w", err) - } - return nil - } - n := binary.BigEndian.Uint64(v) - if n > uint64(count) { - return fmt.Errorf("rate: %w", ErrRate) - } - err = putUint(n+1, c, key, d) - if err != nil { - return fmt.Errorf("rate: %w", err) + return fmt.Errorf("ValidateToken: %w", err) } return nil } -func putUint(n uint64, c cache.Cache, key []byte, d time.Duration) error { - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, n) - err := c.Put(key, b, time.Now().Add(d)) +func (y *Yggdrasil) SignOut(ctx context.Context, t yggdrasil.Pass) error { + u, err := y.validatePass(ctx, t.Username, t.Password) if err != nil { - return fmt.Errorf("rate: %w", err) + return fmt.Errorf("SignOut: %w", err) + } + ut, err := y.client.UserToken.Query().Where(usertoken.UUIDEQ(u.Edges.Profile.UUID)).First(ctx) + if err != nil { + var nf *ent.NotFoundError + if !errors.As(err, &nf) { + return fmt.Errorf("SignOut: %w", err) + } + return nil + } + err = y.client.UserToken.UpdateOne(ut).AddTokenID(1).Exec(ctx) + if err != nil { + return fmt.Errorf("SignOut: %w", err) + } + return nil +} + +func (y *Yggdrasil) Invalidate(ctx context.Context, accessToken string) error { + t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: accessToken}, y.client, y.config.JwtKey) + if err != nil { + return fmt.Errorf("Invalidate: %w", err) + } + err = y.client.UserToken.Update().Where(usertoken.UUIDEQ(t.Subject)).AddTokenID(1).Exec(ctx) + if err != nil { + return fmt.Errorf("Invalidate: %w", err) } return nil } diff --git a/service/yggdrasil/validate.go b/service/yggdrasil/validate.go deleted file mode 100644 index 49fce96..0000000 --- a/service/yggdrasil/validate.go +++ /dev/null @@ -1,17 +0,0 @@ -package yggdrasil - -import ( - "context" - "fmt" - - "github.com/xmdhs/authlib-skin/model/yggdrasil" - "github.com/xmdhs/authlib-skin/service/utils" -) - -func (y *Yggdrasil) ValidateToken(ctx context.Context, t yggdrasil.ValidateToken) error { - err := utils.Auth(ctx, t, y.client, y.config.JwtKey) - if err != nil { - return fmt.Errorf("ValidateToken: %w", err) - } - return nil -} diff --git a/service/yggdrasil/yggdrasil.go b/service/yggdrasil/yggdrasil.go index 8d4798d..a128027 100644 --- a/service/yggdrasil/yggdrasil.go +++ b/service/yggdrasil/yggdrasil.go @@ -1,6 +1,10 @@ package yggdrasil import ( + "encoding/binary" + "fmt" + "time" + "github.com/xmdhs/authlib-skin/config" "github.com/xmdhs/authlib-skin/db/cache" "github.com/xmdhs/authlib-skin/db/ent" @@ -19,3 +23,37 @@ func NewYggdrasil(client *ent.Client, cache cache.Cache, c config.Config) *Yggdr config: c, } } + +func rate(k string, c cache.Cache, d time.Duration, count uint) error { + key := []byte(k) + v, err := c.Get([]byte(key)) + if err != nil { + return fmt.Errorf("rate: %w", err) + } + if v == nil { + err := putUint(1, c, key, d) + if err != nil { + return fmt.Errorf("rate: %w", err) + } + return nil + } + n := binary.BigEndian.Uint64(v) + if n > uint64(count) { + return fmt.Errorf("rate: %w", ErrRate) + } + err = putUint(n+1, c, key, d) + if err != nil { + return fmt.Errorf("rate: %w", err) + } + return nil +} + +func putUint(n uint64, c cache.Cache, key []byte, d time.Duration) error { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, n) + err := c.Put(key, b, time.Now().Add(d)) + if err != nil { + return fmt.Errorf("rate: %w", err) + } + return nil +}