jwt 改用 rsa 签名

This commit is contained in:
xmdhs 2023-09-08 01:29:15 +08:00
parent 90a568ab4e
commit 083d859871
No known key found for this signature in database
GPG Key ID: E809D6D43DEFCC95
7 changed files with 40 additions and 16 deletions

View File

@ -11,11 +11,11 @@ type Config struct {
MysqlDsn string MysqlDsn string
} }
Debug bool Debug bool
JwtKey string
Cache struct { Cache struct {
Type string Type string
Ram int Ram int
} }
RaelIP bool RaelIP bool
MaxIpUser int MaxIpUser int
RsaPriKey string
} }

View File

@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"crypto/rsa"
"database/sql" "database/sql"
"fmt" "fmt"
"log/slog" "log/slog"
@ -15,6 +16,7 @@ import (
"github.com/xmdhs/authlib-skin/db/cache" "github.com/xmdhs/authlib-skin/db/cache"
"github.com/xmdhs/authlib-skin/db/ent" "github.com/xmdhs/authlib-skin/db/ent"
"github.com/xmdhs/authlib-skin/db/ent/migrate" "github.com/xmdhs/authlib-skin/db/ent/migrate"
"github.com/xmdhs/authlib-skin/utils/sign"
) )
func ProvideSlog(c config.Config) slog.Handler { func ProvideSlog(c config.Config) slog.Handler {
@ -77,4 +79,12 @@ func ProvideCache(c config.Config) cache.Cache {
return cache.NewFastCache(c.Cache.Ram) return cache.NewFastCache(c.Cache.Ram)
} }
var Set = wire.NewSet(ProvideSlog, ProvideDB, ProvideEnt, ProvideValidate, ProvideCache) func ProvidePriKey(c config.Config) (*rsa.PrivateKey, error) {
a, err := sign.NewAuthlibSign([]byte(c.RsaPriKey))
if err != nil {
return nil, fmt.Errorf("ProvidePriKey: %w", err)
}
return a.GetKey(), nil
}
var Set = wire.NewSet(ProvideSlog, ProvideDB, ProvideEnt, ProvideValidate, ProvideCache, ProvidePriKey)

View File

@ -37,7 +37,13 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
return nil, nil, err return nil, nil, err
} }
cache := ProvideCache(c) cache := ProvideCache(c)
yggdrasilYggdrasil := yggdrasil.NewYggdrasil(client, cache, c) privateKey, err := ProvidePriKey(c)
if err != nil {
cleanup2()
cleanup()
return nil, nil, err
}
yggdrasilYggdrasil := yggdrasil.NewYggdrasil(client, cache, c, privateKey)
yggdrasil3 := yggdrasil2.NewYggdrasil(logger, validate, yggdrasilYggdrasil) yggdrasil3 := yggdrasil2.NewYggdrasil(logger, validate, yggdrasilYggdrasil)
webService := service.NewWebService(c, client) webService := service.NewWebService(c, client)
handel := handle.NewHandel(webService, validate, c, logger) handel := handle.NewHandel(webService, validate, c, logger)

View File

@ -2,6 +2,7 @@ package utils
import ( import (
"context" "context"
"crypto/rsa"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -18,9 +19,9 @@ var (
ErrTokenInvalid = errors.New("token 无效") ErrTokenInvalid = errors.New("token 无效")
) )
func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, jwtKey string, tmpInvalid bool) (*model.TokenClaims, error) { func Auth(ctx context.Context, t yggdrasil.ValidateToken, client *ent.Client, pubkey *rsa.PublicKey, tmpInvalid bool) (*model.TokenClaims, error) {
token, err := jwt.ParseWithClaims(t.AccessToken, &model.TokenClaims{}, func(t *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(t.AccessToken, &model.TokenClaims{}, func(t *jwt.Token) (interface{}, error) {
return []byte(jwtKey), nil return pubkey, nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("Auth: %w", err) return nil, fmt.Errorf("Auth: %w", err)

View File

@ -79,7 +79,7 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat
return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err)
} }
jwts, err := newJwtToken(y.config.JwtKey, strconv.FormatUint(utoken.TokenID, 10), clientToken, u.Edges.Profile.UUID) jwts, err := newJwtToken(y.prikey, strconv.FormatUint(utoken.TokenID, 10), clientToken, u.Edges.Profile.UUID)
if err != nil { if err != nil {
return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err)
} }
@ -101,7 +101,7 @@ func (y *Yggdrasil) Authenticate(cxt context.Context, auth yggdrasil.Authenticat
} }
func (y *Yggdrasil) ValidateToken(ctx context.Context, t yggdrasil.ValidateToken) error { func (y *Yggdrasil) ValidateToken(ctx context.Context, t yggdrasil.ValidateToken) error {
_, err := sutils.Auth(ctx, t, y.client, y.config.JwtKey, true) _, err := sutils.Auth(ctx, t, y.client, &y.prikey.PublicKey, true)
if err != nil { if err != nil {
return fmt.Errorf("ValidateToken: %w", err) return fmt.Errorf("ValidateToken: %w", err)
} }
@ -129,7 +129,7 @@ func (y *Yggdrasil) SignOut(ctx context.Context, t yggdrasil.Pass) error {
} }
func (y *Yggdrasil) Invalidate(ctx context.Context, accessToken string) error { func (y *Yggdrasil) Invalidate(ctx context.Context, accessToken string) error {
t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: accessToken}, y.client, y.config.JwtKey, true) t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: accessToken}, y.client, &y.prikey.PublicKey, true)
if err != nil { if err != nil {
return fmt.Errorf("Invalidate: %w", err) return fmt.Errorf("Invalidate: %w", err)
} }
@ -141,11 +141,11 @@ func (y *Yggdrasil) Invalidate(ctx context.Context, accessToken string) error {
} }
func (y *Yggdrasil) Refresh(ctx context.Context, token yggdrasil.RefreshToken) (yggdrasil.Token, error) { func (y *Yggdrasil) Refresh(ctx context.Context, token yggdrasil.RefreshToken) (yggdrasil.Token, error) {
t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token.AccessToken, ClientToken: token.ClientToken}, y.client, y.config.JwtKey, false) t, err := sutils.Auth(ctx, yggdrasil.ValidateToken{AccessToken: token.AccessToken, ClientToken: token.ClientToken}, y.client, &y.prikey.PublicKey, false)
if err != nil { if err != nil {
return yggdrasil.Token{}, fmt.Errorf("Refresh: %w", err) return yggdrasil.Token{}, fmt.Errorf("Refresh: %w", err)
} }
jwts, err := newJwtToken(y.config.JwtKey, t.Tid, t.CID, t.Subject) jwts, err := newJwtToken(y.prikey, t.Tid, t.CID, t.Subject)
if err != nil { if err != nil {
return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err) return yggdrasil.Token{}, fmt.Errorf("Authenticate: %w", err)
} }

View File

@ -1,6 +1,7 @@
package yggdrasil package yggdrasil
import ( import (
"crypto/rsa"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"time" "time"
@ -16,13 +17,15 @@ type Yggdrasil struct {
client *ent.Client client *ent.Client
cache cache.Cache cache cache.Cache
config config.Config config config.Config
prikey *rsa.PrivateKey
} }
func NewYggdrasil(client *ent.Client, cache cache.Cache, c config.Config) *Yggdrasil { func NewYggdrasil(client *ent.Client, cache cache.Cache, c config.Config, prikey *rsa.PrivateKey) *Yggdrasil {
return &Yggdrasil{ return &Yggdrasil{
client: client, client: client,
cache: cache, cache: cache,
config: c, config: c,
prikey: prikey,
} }
} }
@ -60,7 +63,7 @@ func putUint(n uint64, c cache.Cache, key []byte, d time.Duration) error {
return nil return nil
} }
func newJwtToken(jwtKey string, tokenID, clientToken, UUID string) (string, error) { func newJwtToken(jwtKey *rsa.PrivateKey, tokenID, clientToken, UUID string) (string, error) {
claims := model.TokenClaims{ claims := model.TokenClaims{
Tid: tokenID, Tid: tokenID,
CID: clientToken, CID: clientToken,
@ -71,8 +74,8 @@ func newJwtToken(jwtKey string, tokenID, clientToken, UUID string) (string, erro
IssuedAt: jwt.NewNumericDate(time.Now()), IssuedAt: jwt.NewNumericDate(time.Now()),
}, },
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
jwts, err := token.SignedString([]byte(jwtKey)) jwts, err := token.SignedString(jwtKey)
if err != nil { if err != nil {
return "", fmt.Errorf("newJwtToken: %w", err) return "", fmt.Errorf("newJwtToken: %w", err)
} }

View File

@ -38,6 +38,10 @@ func NewAuthlibSignWithKey(key *rsa.PrivateKey) *AuthlibSign {
} }
} }
func (a *AuthlibSign) GetKey() *rsa.PrivateKey {
return a.key
}
func (a *AuthlibSign) GetPubKey() (string, error) { func (a *AuthlibSign) GetPubKey() (string, error) {
derBytes := x509.MarshalPKCS1PublicKey(&a.key.PublicKey) derBytes := x509.MarshalPKCS1PublicKey(&a.key.PublicKey)
pemKey := &pem.Block{ pemKey := &pem.Block{