From 083d8598716dcba8b3952fe0f7a776356800cce7 Mon Sep 17 00:00:00 2001 From: xmdhs Date: Fri, 8 Sep 2023 01:29:15 +0800 Subject: [PATCH] =?UTF-8?q?jwt=20=E6=94=B9=E7=94=A8=20rsa=20=E7=AD=BE?= =?UTF-8?q?=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.go | 6 +++--- server/provide.go | 12 +++++++++++- server/wire_gen.go | 8 +++++++- service/utils/auth.go | 5 +++-- service/yggdrasil/user.go | 10 +++++----- service/yggdrasil/yggdrasil.go | 11 +++++++---- utils/sign/rsa.go | 4 ++++ 7 files changed, 40 insertions(+), 16 deletions(-) diff --git a/config/config.go b/config/config.go index a651025..db39102 100644 --- a/config/config.go +++ b/config/config.go @@ -10,12 +10,12 @@ type Config struct { Sql struct { MysqlDsn string } - Debug bool - JwtKey string - Cache struct { + Debug bool + Cache struct { Type string Ram int } RaelIP bool MaxIpUser int + RsaPriKey string } diff --git a/server/provide.go b/server/provide.go index de5c037..439c1fa 100644 --- a/server/provide.go +++ b/server/provide.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/rsa" "database/sql" "fmt" "log/slog" @@ -15,6 +16,7 @@ import ( "github.com/xmdhs/authlib-skin/db/cache" "github.com/xmdhs/authlib-skin/db/ent" "github.com/xmdhs/authlib-skin/db/ent/migrate" + "github.com/xmdhs/authlib-skin/utils/sign" ) func ProvideSlog(c config.Config) slog.Handler { @@ -77,4 +79,12 @@ func ProvideCache(c config.Config) cache.Cache { return cache.NewFastCache(c.Cache.Ram) } -var Set = wire.NewSet(ProvideSlog, ProvideDB, ProvideEnt, ProvideValidate, ProvideCache) +func ProvidePriKey(c config.Config) (*rsa.PrivateKey, error) { + a, err := sign.NewAuthlibSign([]byte(c.RsaPriKey)) + if err != nil { + return nil, fmt.Errorf("ProvidePriKey: %w", err) + } + return a.GetKey(), nil +} + +var Set = wire.NewSet(ProvideSlog, ProvideDB, ProvideEnt, ProvideValidate, ProvideCache, ProvidePriKey) diff --git a/server/wire_gen.go b/server/wire_gen.go index 1bdcc1c..03d63b8 100644 --- a/server/wire_gen.go +++ b/server/wire_gen.go @@ -37,7 +37,13 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func() return nil, nil, err } cache := ProvideCache(c) - yggdrasilYggdrasil := yggdrasil.NewYggdrasil(client, cache, c) + privateKey, err := ProvidePriKey(c) + if err != nil { + cleanup2() + cleanup() + return nil, nil, err + } + yggdrasilYggdrasil := yggdrasil.NewYggdrasil(client, cache, c, privateKey) yggdrasil3 := yggdrasil2.NewYggdrasil(logger, validate, yggdrasilYggdrasil) webService := service.NewWebService(c, client) handel := handle.NewHandel(webService, validate, c, logger) diff --git a/service/utils/auth.go b/service/utils/auth.go index 10aada7..1c824f2 100644 --- a/service/utils/auth.go +++ b/service/utils/auth.go @@ -2,6 +2,7 @@ package utils import ( "context" + "crypto/rsa" "errors" "fmt" "strconv" @@ -18,9 +19,9 @@ var ( ErrTokenInvalid = errors.New("token 无效") ) -func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, jwtKey string, tmpInvalid bool) (*model.TokenClaims, error) { +func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, pubkey *rsa.PublicKey, tmpInvalid bool) (*model.TokenClaims, error) { token, err := jwt.ParseWithClaims(t.AccessToken, &model.TokenClaims{}, func(t *jwt.Token) (interface{}, error) { - return []byte(jwtKey), nil + return pubkey, nil }) if err != nil { return nil, fmt.Errorf("Auth: %w", err) diff --git a/service/yggdrasil/user.go b/service/yggdrasil/user.go index 069d281..6578520 100644 --- a/service/yggdrasil/user.go +++ b/service/yggdrasil/user.go @@ -79,7 +79,7 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) } - jwts, err := newJwtToken(y.config.JwtKey, strconv.FormatUint(utoken.TokenID, 10), clientToken, u.Edges.Profile.UUID) + jwts, err := newJwtToken(y.prikey, strconv.FormatUint(utoken.TokenID, 10), clientToken, u.Edges.Profile.UUID) if err != nil { return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) } @@ -101,7 +101,7 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat } func (y *Yggdrasil) ValidateToken(ctx context.Context, t yggdrasil.ValidateToken) error { - _, err := sutils.Auth(ctx, t, y.client, y.config.JwtKey, true) + _, err := sutils.Auth(ctx, t, y.client, &y.prikey.PublicKey, true) if err != nil { return fmt.Errorf("ValidateToken: %w", err) } @@ -129,7 +129,7 @@ func (y *Yggdrasil) SignOut(ctx context.Context, t yggdrasil.Pass) error { } func (y *Yggdrasil) Invalidate(ctx context.Context, accessToken string) error { - t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: accessToken}, y.client, y.config.JwtKey, true) + t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: accessToken}, y.client, &y.prikey.PublicKey, true) if err != nil { return fmt.Errorf("Invalidate: %w", err) } @@ -141,11 +141,11 @@ func (y *Yggdrasil) Invalidate(ctx context.Context, accessToken string) error { } func (y *Yggdrasil) Refresh(ctx context.Context, token yggdrasil.RefreshToken) (yggdrasil.Token, error) { - t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token.AccessToken, ClientToken: token.ClientToken}, y.client, y.config.JwtKey, false) + t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token.AccessToken, ClientToken: token.ClientToken}, y.client, &y.prikey.PublicKey, false) if err != nil { return yggdrasil.Token{}, fmt.Errorf("Refresh: %w", err) } - jwts, err := newJwtToken(y.config.JwtKey, t.Tid, t.CID, t.Subject) + jwts, err := newJwtToken(y.prikey, t.Tid, t.CID, t.Subject) if err != nil { return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) } diff --git a/service/yggdrasil/yggdrasil.go b/service/yggdrasil/yggdrasil.go index a355989..c15621d 100644 --- a/service/yggdrasil/yggdrasil.go +++ b/service/yggdrasil/yggdrasil.go @@ -1,6 +1,7 @@ package yggdrasil import ( + "crypto/rsa" "encoding/binary" "fmt" "time" @@ -16,13 +17,15 @@ type Yggdrasil struct { client *ent.Client cache cache.Cache config config.Config + prikey *rsa.PrivateKey } -func NewYggdrasil(client *ent.Client, cache cache.Cache, c config.Config) *Yggdrasil { +func NewYggdrasil(client *ent.Client, cache cache.Cache, c config.Config, prikey *rsa.PrivateKey) *Yggdrasil { return &Yggdrasil{ client: client, cache: cache, config: c, + prikey: prikey, } } @@ -60,7 +63,7 @@ func putUint(n uint64, c cache.Cache, key []byte, d time.Duration) error { return nil } -func newJwtToken(jwtKey string, tokenID, clientToken, UUID string) (string, error) { +func newJwtToken(jwtKey *rsa.PrivateKey, tokenID, clientToken, UUID string) (string, error) { claims := model.TokenClaims{ Tid: tokenID, CID: clientToken, @@ -71,8 +74,8 @@ func newJwtToken(jwtKey string, tokenID, clientToken, UUID string) (string, erro IssuedAt: jwt.NewNumericDate(time.Now()), }, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - jwts, err := token.SignedString([]byte(jwtKey)) + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + jwts, err := token.SignedString(jwtKey) if err != nil { return "", fmt.Errorf("newJwtToken: %w", err) } diff --git a/utils/sign/rsa.go b/utils/sign/rsa.go index 31be391..7b854f1 100644 --- a/utils/sign/rsa.go +++ b/utils/sign/rsa.go @@ -38,6 +38,10 @@ func NewAuthlibSignWithKey(key *rsa.PrivateKey) *AuthlibSign { } } +func (a *AuthlibSign) GetKey() *rsa.PrivateKey { + return a.key +} + func (a *AuthlibSign) GetPubKey() (string, error) { derBytes := x509.MarshalPKCS1PublicKey(&a.key.PublicKey) pemKey := &pem.Block{