2023-09-04 15:41:19 +08:00

59 lines
1.4 KiB
Go

package utils
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/xmdhs/authlib-skin/db/ent"
"github.com/xmdhs/authlib-skin/db/ent/usertoken"
"github.com/xmdhs/authlib-skin/model"
"github.com/xmdhs/authlib-skin/model/yggdrasil"
)
// ErrTokenInvalid is the sentinel wrapped by Auth when a token is rejected
// after parsing: wrong claims type or invalid token, client token / CID
// mismatch, expiry reached, or a Tid that does not match the stored TokenID.
// Callers can detect it with errors.Is. (The message text means
// "token invalid".)
var ErrTokenInvalid = errors.New("token 无效")
// Auth validates a Yggdrasil access token and returns nil if it is acceptable.
//
// The checks, in order, are:
//   - t.AccessToken must parse as an HMAC-signed JWT (verified with jwtKey)
//     whose claims decode into model.TokenClaims and pass the standard claim
//     validation performed by jwt.ParseWithClaims;
//   - if t.ClientToken is non-empty it must equal the token's CID claim;
//   - the token must carry an exp claim that is not in the past;
//   - the Tid claim must equal the TokenID stored in the usertoken row whose
//     UUID is the token's subject (NOTE(review): presumably this lets the
//     server invalidate outstanding tokens by changing the stored TokenID).
//
// Rejections detected by these checks are returned as errors wrapping
// ErrTokenInvalid. JWT parse errors and database errors (including the ent
// error returned when no usertoken row exists) are wrapped unchanged.
func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, jwtKey string) error {
	token, err := jwt.ParseWithClaims(t.AccessToken, &model.TokenClaims{}, func(tk *jwt.Token) (interface{}, error) {
		// The key is a shared secret, so only HMAC algorithms can be ones we
		// issued; refuse anything else named in the token header.
		if _, ok := tk.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method %v: %w", tk.Header["alg"], ErrTokenInvalid)
		}
		return []byte(jwtKey), nil
	})
	if err != nil {
		return fmt.Errorf("Auth: %w", err)
	}
	claims, ok := token.Claims.(*model.TokenClaims)
	if !ok || !token.Valid {
		return fmt.Errorf("Auth: %w", ErrTokenInvalid)
	}
	if t.ClientToken != "" && t.ClientToken != claims.CID {
		return fmt.Errorf("Auth: %w", ErrTokenInvalid)
	}
	exp, err := claims.GetExpirationTime()
	if err != nil {
		return fmt.Errorf("Auth: %w", err)
	}
	// GetExpirationTime returns a nil *NumericDate (and nil error) when the
	// exp claim is absent; treat that as invalid instead of dereferencing it.
	// (The former iat + (exp - iat) computation was simply exp.)
	if exp == nil || time.Now().After(exp.Time) {
		return fmt.Errorf("Auth: %w", ErrTokenInvalid)
	}
	ut, err := client.UserToken.Query().Where(usertoken.UUIDEQ(claims.Subject)).First(ctx)
	if err != nil {
		return fmt.Errorf("Auth: %w", err)
	}
	if strconv.FormatUint(ut.TokenID, 10) != claims.Tid {
		return fmt.Errorf("Auth: %w", ErrTokenInvalid)
	}
	return nil
}