diff --git a/db/ent/migrate/schema.go b/db/ent/migrate/schema.go index 541e4f6..8abe3d3 100644 --- a/db/ent/migrate/schema.go +++ b/db/ent/migrate/schema.go @@ -97,9 +97,9 @@ var ( }, Indexes: []*schema.Index{ { - Name: "userprofile_name", + Name: "userprofile_uuid", Unique: false, - Columns: []*schema.Column{UserProfilesColumns[1]}, + Columns: []*schema.Column{UserProfilesColumns[2]}, }, }, } diff --git a/db/ent/schema/userprofile.go b/db/ent/schema/userprofile.go index 949806f..88efd4a 100644 --- a/db/ent/schema/userprofile.go +++ b/db/ent/schema/userprofile.go @@ -34,6 +34,6 @@ func (UserProfile) Edges() []ent.Edge { func (UserProfile) Indexes() []ent.Index { return []ent.Index{ - index.Fields("name"), + index.Fields("uuid"), } } diff --git a/handle/yggdrasil/user.go b/handle/yggdrasil/user.go index d0e723b..49a2674 100644 --- a/handle/yggdrasil/user.go +++ b/handle/yggdrasil/user.go @@ -96,3 +96,26 @@ func (y *Yggdrasil) Invalidate() httprouter.Handle { } } } + +func (y *Yggdrasil) Refresh() httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + cxt := r.Context() + a, has := getAnyModel[yggdrasil.RefreshToken](cxt, w, r.Body, y.validate, y.logger) + if !has { + return + } + t, err := y.yggdrasilService.Refresh(cxt, a) + if err != nil { + if errors.Is(err, sutils.ErrTokenInvalid) { + y.logger.DebugContext(cxt, err.Error()) + handleYgError(cxt, w, yggdrasil.Error{ErrorMessage: "Invalid token.", Error: "ForbiddenOperationException"}, 403) + return + } + y.logger.WarnContext(cxt, err.Error()) + handleYgError(cxt, w, yggdrasil.Error{ErrorMessage: err.Error()}, 500) + return + } + b, _ := json.Marshal(t) + w.Write(b) + } +} diff --git a/model/yggdrasil/model.go b/model/yggdrasil/model.go index 93e7ee0..c12822f 100644 --- a/model/yggdrasil/model.go +++ b/model/yggdrasil/model.go @@ -1,7 +1,8 @@ package yggdrasil type Pass struct { - Username string `json:"username" validate:"required"` + // 目前只能是 email + Username string `json:"username" validate:"required,email"` Password string `json:"password" validate:"required"` } @@ -23,13 +24,14 @@ type Error struct { type Token struct { AccessToken string `json:"accessToken"` - AvailableProfiles []TokenProfile `json:"availableProfiles"` + AvailableProfiles []TokenProfile `json:"availableProfiles,omitempty"` ClientToken string `json:"clientToken"` - SelectedProfile TokenProfile `json:"selectedProfile"` + SelectedProfile TokenProfile `json:"selectedProfile,omitempty"` User TokenUser `json:"user,omitempty"` } type TokenProfile struct { + // 就是 uuid ID string `json:"id"` Name string `json:"name"` } @@ -43,3 +45,9 @@ type ValidateToken struct { AccessToken string `json:"accessToken" validate:"required,jwt"` ClientToken string `json:"clientToken"` } + +type RefreshToken struct { + ValidateToken + RequestUser bool `json:"requestUser"` + SelectedProfile TokenProfile `json:"selectedProfile"` +} diff --git a/service/utils/auth.go b/service/utils/auth.go index 32efeba..10aada7 100644 --- a/service/utils/auth.go +++ b/service/utils/auth.go @@ -18,7 +18,7 @@ var ( ErrTokenInvalid = errors.New("token 无效") ) -func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, jwtKey string) (*model.TokenClaims, error) { +func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, jwtKey string, tmpInvalid bool) (*model.TokenClaims, error) { token, err := jwt.ParseWithClaims(t.AccessToken, &model.TokenClaims{}, func(t *jwt.Token) (interface{}, error) { return []byte(jwtKey), nil }) @@ -34,17 +34,19 @@ func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, jw return nil, fmt.Errorf("Auth: %w", ErrTokenInvalid) } - it, err := claims.GetIssuedAt() - if err != nil { - return nil, fmt.Errorf("Auth: %w", err) - } - et, err := claims.GetExpirationTime() - if err != nil { - return nil, fmt.Errorf("Auth: %w", err) - } - invalidTime := it.Add(et.Time.Sub(it.Time)) - if time.Now().After(invalidTime) { - return nil, fmt.Errorf("Auth: %w", ErrTokenInvalid) + if tmpInvalid { + it, err := claims.GetIssuedAt() + if err != nil { + return nil, fmt.Errorf("Auth: %w", err) + } + et, err := claims.GetExpirationTime() + if err != nil { + return nil, fmt.Errorf("Auth: %w", err) + } + invalidTime := it.Add(et.Time.Sub(it.Time) / 2) + if time.Now().After(invalidTime) { + return nil, fmt.Errorf("Auth: %w", ErrTokenInvalid) + } } ut, err := client.UserToken.Query().Where(usertoken.UUIDEQ(claims.Subject)).First(ctx) diff --git a/service/yggdrasil/user.go b/service/yggdrasil/user.go index c8f7c63..b184fa4 100644 --- a/service/yggdrasil/user.go +++ b/service/yggdrasil/user.go @@ -8,12 +8,10 @@ import ( "strings" "time" - "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/xmdhs/authlib-skin/db/ent" "github.com/xmdhs/authlib-skin/db/ent/user" "github.com/xmdhs/authlib-skin/db/ent/usertoken" - "github.com/xmdhs/authlib-skin/model" "github.com/xmdhs/authlib-skin/model/yggdrasil" sutils "github.com/xmdhs/authlib-skin/service/utils" "github.com/xmdhs/authlib-skin/utils" @@ -80,26 +78,15 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) } - claims := model.TokenClaims{ - Tid: strconv.FormatUint(utoken.TokenID, 10), - CID: clientToken, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(15 * 24 * time.Hour)), - Issuer: "authlib-skin", - Subject: u.Edges.Profile.UUID, - IssuedAt: jwt.NewNumericDate(time.Now()), - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - jwts, err := token.SignedString([]byte(y.config.JwtKey)) + jwts, err := newJwtToken(y.config.JwtKey, strconv.FormatUint(utoken.TokenID, 10), clientToken, u.Edges.Profile.UUID) if err != nil { return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) } + p := yggdrasil.TokenProfile{ ID: u.Edges.Profile.UUID, Name: u.Edges.Profile.Name, } - return yggdrasil.Token{ AccessToken: jwts, AvailableProfiles: []yggdrasil.TokenProfile{p}, @@ -113,7 +100,7 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat } func (y *Yggdrasil) ValidateToken(ctx context.Context, t yggdrasil.ValidateToken) error { - _, err := sutils.Auth(ctx, t, y.client, y.config.JwtKey) + _, err := sutils.Auth(ctx, t, y.client, y.config.JwtKey, true) if err != nil { return fmt.Errorf("ValidateToken: %w", err) } @@ -141,7 +128,7 @@ func (y *Yggdrasil) SignOut(ctx context.Context, t yggdrasil.Pass) error { } func (y *Yggdrasil) Invalidate(ctx context.Context, accessToken string) error { - t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: accessToken}, y.client, y.config.JwtKey) + t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: accessToken}, y.client, y.config.JwtKey, true) if err != nil { return fmt.Errorf("Invalidate: %w", err) } @@ -151,3 +138,21 @@ func (y *Yggdrasil) Invalidate(ctx context.Context, accessToken string) error { } return nil } + +func (y *Yggdrasil) Refresh(ctx context.Context, token yggdrasil.RefreshToken) (yggdrasil.Token, error) { + t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token.AccessToken, ClientToken: token.ClientToken}, y.client, y.config.JwtKey, false) + if err != nil { + return yggdrasil.Token{}, fmt.Errorf("Refresh: %w", err) + } + jwts, err := newJwtToken(y.config.JwtKey, t.Tid, t.CID, t.Subject) + if err != nil { + return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) + } + return yggdrasil.Token{ + AccessToken: jwts, + ClientToken: t.CID, + User: yggdrasil.TokenUser{ + ID: t.Subject, + }, + }, nil +} diff --git a/service/yggdrasil/yggdrasil.go b/service/yggdrasil/yggdrasil.go index a128027..a355989 100644 --- a/service/yggdrasil/yggdrasil.go +++ b/service/yggdrasil/yggdrasil.go @@ -5,9 +5,11 @@ import ( "fmt" "time" + "github.com/golang-jwt/jwt/v5" "github.com/xmdhs/authlib-skin/config" "github.com/xmdhs/authlib-skin/db/cache" "github.com/xmdhs/authlib-skin/db/ent" + "github.com/xmdhs/authlib-skin/model" ) type Yggdrasil struct { @@ -57,3 +59,22 @@ func putUint(n uint64, c cache.Cache, key []byte, d time.Duration) error { } return nil } + +func newJwtToken(jwtKey string, tokenID, clientToken, UUID string) (string, error) { + claims := model.TokenClaims{ + Tid: tokenID, + CID: clientToken, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(15 * 24 * time.Hour)), + Issuer: "authlib-skin", + Subject: UUID, + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + jwts, err := token.SignedString([]byte(jwtKey)) + if err != nil { + return "", fmt.Errorf("newJwtToken: %w", err) + } + return jwts, nil +}