From 109f284a11df5556a43607d5ab203f738834d804 Mon Sep 17 00:00:00 2001 From: xmdhs Date: Sat, 2 Sep 2023 00:33:00 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B3=A8=E5=86=8C=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.go | 5 ++ db/mysql/db.go | 44 +++++++++++++----- db/mysql/models.go | 20 +++++++- db/mysql/querier.go | 2 + db/mysql/query.sql.go | 77 ++++++++++++++++++++++--------- db/mysql/sql/query.sql | 38 ++++++++++----- db/mysql/sql/schema.sql | 1 + go.mod | 2 + go.sum | 4 ++ handle/user.go | 24 ++++++++++ handle/yggdrasil/authenticate.go | 16 ++----- model/model.go | 6 +++ server/slog.go | 42 +++++++++++++++++ service/user.go | 79 ++++++++++++++++++++++++++++++++ utils/argon2id.go | 28 +++++++++++ utils/decode.go | 24 ++++++++++ utils/tx.go | 37 +++++++++++++++ 17 files changed, 390 insertions(+), 59 deletions(-) create mode 100644 config/config.go create mode 100644 server/slog.go create mode 100644 service/user.go create mode 100644 utils/argon2id.go create mode 100644 utils/decode.go create mode 100644 utils/tx.go diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..976e68c --- /dev/null +++ b/config/config.go @@ -0,0 +1,5 @@ +package config + +type Config struct { + OfflineUUID bool +} diff --git a/db/mysql/db.go b/db/mysql/db.go index 5240067..48af8b0 100644 --- a/db/mysql/db.go +++ b/db/mysql/db.go @@ -25,12 +25,18 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.createUserStmt, err = db.PrepareContext(ctx, createUser); err != nil { return nil, fmt.Errorf("error preparing query CreateUser: %w", err) } + if q.createUserProfileStmt, err = db.PrepareContext(ctx, createUserProfile); err != nil { + return nil, fmt.Errorf("error preparing query CreateUserProfile: %w", err) + } if q.deleteUserStmt, err = db.PrepareContext(ctx, deleteUser); err != nil { return nil, fmt.Errorf("error preparing query DeleteUser: %w", err) } if q.getUserStmt, err = db.PrepareContext(ctx, getUser); err != nil { return nil, fmt.Errorf("error preparing query GetUser: %w", err) } + if q.getUserByEmailStmt, err = db.PrepareContext(ctx, getUserByEmail); err != nil { + return nil, fmt.Errorf("error preparing query GetUserByEmail: %w", err) + } if q.listUserStmt, err = db.PrepareContext(ctx, listUser); err != nil { return nil, fmt.Errorf("error preparing query ListUser: %w", err) } @@ -44,6 +50,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing createUserStmt: %w", cerr) } } + if q.createUserProfileStmt != nil { + if cerr := q.createUserProfileStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing createUserProfileStmt: %w", cerr) + } + } if q.deleteUserStmt != nil { if cerr := q.deleteUserStmt.Close(); cerr != nil { err = fmt.Errorf("error closing deleteUserStmt: %w", cerr) @@ -54,6 +65,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getUserStmt: %w", cerr) } } + if q.getUserByEmailStmt != nil { + if cerr := q.getUserByEmailStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getUserByEmailStmt: %w", cerr) + } + } if q.listUserStmt != nil { if cerr := q.listUserStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listUserStmt: %w", cerr) @@ -96,21 +112,25 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar } type Queries struct { - db DBTX - tx *sql.Tx - createUserStmt *sql.Stmt - deleteUserStmt *sql.Stmt - getUserStmt *sql.Stmt - listUserStmt *sql.Stmt + db DBTX + tx *sql.Tx + createUserStmt *sql.Stmt + createUserProfileStmt *sql.Stmt + deleteUserStmt *sql.Stmt + getUserStmt *sql.Stmt + getUserByEmailStmt *sql.Stmt + listUserStmt *sql.Stmt } func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ - db: tx, - tx: tx, - createUserStmt: q.createUserStmt, - deleteUserStmt: q.deleteUserStmt, - getUserStmt: q.getUserStmt, - listUserStmt: q.listUserStmt, + db: tx, + tx: tx, + createUserStmt: q.createUserStmt, + createUserProfileStmt: q.createUserProfileStmt, + deleteUserStmt: q.deleteUserStmt, + getUserStmt: q.getUserStmt, + getUserByEmailStmt: q.getUserByEmailStmt, + listUserStmt: q.listUserStmt, } } diff --git a/db/mysql/models.go b/db/mysql/models.go index bb85e9e..7fbfccb 100644 --- a/db/mysql/models.go +++ b/db/mysql/models.go @@ -5,6 +5,7 @@ package mysql import () type Skin struct { + ID int64 `db:"id"` UserID int64 `db:"user_id"` SkinHash string `db:"skin_hash"` Type string `db:"type"` @@ -16,7 +17,22 @@ type User struct { Email string `db:"email"` Password string `db:"password"` Salt string `db:"salt"` - Disabled int32 `db:"disabled"` - Admin int32 `db:"admin"` + State int32 `db:"state"` RegTime int64 `db:"reg_time"` } + +type UserProfile struct { + UserID int64 `db:"user_id"` + Name string `db:"name"` + Uuid string `db:"uuid"` +} + +type UserSkin struct { + UserID int64 `db:"user_id"` + SkinID int64 `db:"skin_id"` +} + +type UserToken struct { + UserID int64 `db:"user_id"` + TokenID int32 `db:"token_id"` +} diff --git a/db/mysql/querier.go b/db/mysql/querier.go index 6348d4d..157f94a 100644 --- a/db/mysql/querier.go +++ b/db/mysql/querier.go @@ -9,8 +9,10 @@ import ( type Querier interface { CreateUser(ctx context.Context, arg CreateUserParams) (sql.Result, error) + CreateUserProfile(ctx context.Context, arg CreateUserProfileParams) (sql.Result, error) DeleteUser(ctx context.Context, id int64) error GetUser(ctx context.Context, id int64) (User, error) + GetUserByEmail(ctx context.Context, email string) (User, error) ListUser(ctx context.Context) ([]User, error) } diff --git a/db/mysql/query.sql.go b/db/mysql/query.sql.go index a3b028f..5c2fbcd 100644 --- a/db/mysql/query.sql.go +++ b/db/mysql/query.sql.go @@ -9,18 +9,16 @@ import ( ) const createUser = `-- name: CreateUser :execresult -INSERT INTO - user ( - id, - email, - password, - salt, - disabled, - admin, - reg_time - ) +REPLACE INTO user ( + id, + email, + password, + salt, + state, + reg_time +) VALUES - (?, ?, ?, ?, ?, ?, ?) + (?, ?, ?, ?, ?, ?) ` type CreateUserParams struct { @@ -28,8 +26,7 @@ type CreateUserParams struct { Email string `db:"email"` Password string `db:"password"` Salt string `db:"salt"` - Disabled int32 `db:"disabled"` - Admin int32 `db:"admin"` + State int32 `db:"state"` RegTime int64 `db:"reg_time"` } @@ -39,12 +36,27 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (sql.Res arg.Email, arg.Password, arg.Salt, - arg.Disabled, - arg.Admin, + arg.State, arg.RegTime, ) } +const createUserProfile = `-- name: CreateUserProfile :execresult +REPLACE INTO ` + "`" + `user_profile` + "`" + ` (` + "`" + `user_id` + "`" + `, ` + "`" + `name` + "`" + `, ` + "`" + `uuid` + "`" + `) +VALUES + (?, ?, ?) +` + +type CreateUserProfileParams struct { + UserID int64 `db:"user_id"` + Name string `db:"name"` + Uuid string `db:"uuid"` +} + +func (q *Queries) CreateUserProfile(ctx context.Context, arg CreateUserProfileParams) (sql.Result, error) { + return q.exec(ctx, q.createUserProfileStmt, createUserProfile, arg.UserID, arg.Name, arg.Uuid) +} + const deleteUser = `-- name: DeleteUser :exec DELETE FROM user @@ -59,7 +71,7 @@ func (q *Queries) DeleteUser(ctx context.Context, id int64) error { const getUser = `-- name: GetUser :one SELECT - id, email, password, salt, disabled, admin, reg_time + id, email, password, salt, state, reg_time FROM user WHERE @@ -76,8 +88,32 @@ func (q *Queries) GetUser(ctx context.Context, id int64) (User, error) { &i.Email, &i.Password, &i.Salt, - &i.Disabled, - &i.Admin, + &i.State, + &i.RegTime, + ) + return i, err +} + +const getUserByEmail = `-- name: GetUserByEmail :one +SELECT + id, email, password, salt, state, reg_time +FROM + user +WHERE + email = ? +LIMIT + 1 +` + +func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) { + row := q.queryRow(ctx, q.getUserByEmailStmt, getUserByEmail, email) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.Password, + &i.Salt, + &i.State, &i.RegTime, ) return i, err @@ -85,7 +121,7 @@ func (q *Queries) GetUser(ctx context.Context, id int64) (User, error) { const listUser = `-- name: ListUser :many SELECT - id, email, password, salt, disabled, admin, reg_time + id, email, password, salt, state, reg_time FROM user ORDER BY @@ -106,8 +142,7 @@ func (q *Queries) ListUser(ctx context.Context) ([]User, error) { &i.Email, &i.Password, &i.Salt, - &i.Disabled, - &i.Admin, + &i.State, &i.RegTime, ); err != nil { return nil, err diff --git a/db/mysql/sql/query.sql b/db/mysql/sql/query.sql index 8424210..fa3872c 100644 --- a/db/mysql/sql/query.sql +++ b/db/mysql/sql/query.sql @@ -17,21 +17,35 @@ ORDER BY reg_time; -- name: CreateUser :execresult -INSERT INTO - user ( - id, - email, - password, - salt, - disabled, - admin, - reg_time - ) +REPLACE INTO user ( + id, + email, + password, + salt, + state, + reg_time +) VALUES - (?, ?, ?, ?, ?, ?, ?); + (?, ?, ?, ?, ?, ?); -- name: DeleteUser :exec DELETE FROM user WHERE - id = ?; \ No newline at end of file + id = ?; + +-- name: CreateUserProfile :execresult +REPLACE INTO `user_profile` (`user_id`, `name`, `uuid`) +VALUES + (?, ?, ?); + + +-- name: GetUserByEmail :one +SELECT + * +FROM + user +WHERE + email = ? +LIMIT + 1; \ No newline at end of file diff --git a/db/mysql/sql/schema.sql b/db/mysql/sql/schema.sql index 2e9da0d..b7929cb 100644 --- a/db/mysql/sql/schema.sql +++ b/db/mysql/sql/schema.sql @@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS `user` ( email VARCHAR(20) NOT NULL, password text NOT NULL, salt text NOT NULL, + -- 二进制状态位,暂无作用 state INT NOT NULL, reg_time BIGINT NOT NULL ); diff --git a/go.mod b/go.mod index c68817a..f66c92c 100644 --- a/go.mod +++ b/go.mod @@ -8,9 +8,11 @@ require ( ) require ( + github.com/bwmarrin/snowflake v0.3.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/google/uuid v1.3.1 // indirect github.com/leodido/go-urn v1.2.4 // indirect golang.org/x/crypto v0.7.0 // indirect golang.org/x/net v0.8.0 // indirect diff --git a/go.sum b/go.sum index 11841cc..6eaad15 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0= +github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -11,6 +13,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.15.3 h1:S+sSpunYjNPDuXkWbK+x+bA7iXiW296KG4dL3X7xUZo= github.com/go-playground/validator/v10 v10.15.3/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= diff --git a/handle/user.go b/handle/user.go index 944bf07..2477f74 100644 --- a/handle/user.go +++ b/handle/user.go @@ -1 +1,25 @@ package handle + +import ( + "log/slog" + "net/http" + + "github.com/go-playground/validator/v10" + "github.com/julienschmidt/httprouter" + "github.com/xmdhs/authlib-skin/db/mysql" + "github.com/xmdhs/authlib-skin/model" + "github.com/xmdhs/authlib-skin/utils" +) + +func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + ctx := r.Context() + + u, err := utils.DeCodeBody[model.User](r.Body, v) + if err != nil { + l.InfoContext(ctx, err.Error()) + } + _ = u + + } +} diff --git a/handle/yggdrasil/authenticate.go b/handle/yggdrasil/authenticate.go index e9fb0fa..8d4782b 100644 --- a/handle/yggdrasil/authenticate.go +++ b/handle/yggdrasil/authenticate.go @@ -1,33 +1,25 @@ package yggdrasil import ( - "encoding/json" "log/slog" "net/http" "github.com/go-playground/validator/v10" "github.com/julienschmidt/httprouter" "github.com/xmdhs/authlib-skin/model/yggdrasil" + "github.com/xmdhs/authlib-skin/utils" ) func Authenticate(l *slog.Logger, v *validator.Validate) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { cxt := r.Context() - jr := json.NewDecoder(r.Body) - var a yggdrasil.Authenticate - err := jr.Decode(&a) + a, err := utils.DeCodeBody[yggdrasil.Authenticate](r.Body, v) if err != nil { - l.Info(err.Error()) - handleYgError(cxt, w, yggdrasil.Error{ErrorMessage: err.Error()}, 400) - return - } - - err = v.Struct(a) - if err != nil { - l.Info(err.Error()) + l.InfoContext(cxt, err.Error()) handleYgError(cxt, w, yggdrasil.Error{ErrorMessage: err.Error()}, 400) return } + _ = a } } diff --git a/model/model.go b/model/model.go index ba0a8f4..75e1c1d 100644 --- a/model/model.go +++ b/model/model.go @@ -5,3 +5,9 @@ type API[T any] struct { Data T `json:"data"` Msg string `json:"msg"` } + +type User struct { + Email string `validate:"required,email"` + Password string `validate:"required,sha256"` + Name string `validate:"required,min=3,max=16"` +} diff --git a/server/slog.go b/server/slog.go new file mode 100644 index 0000000..e0a47c5 --- /dev/null +++ b/server/slog.go @@ -0,0 +1,42 @@ +package server + +import ( + "context" + "log/slog" +) + +type reqInfo struct { + URL string + IP string + TrackId uint64 +} + +type reqInfoKeyType string + +var reqinfoKey reqInfoKeyType = "reqinfoKey" + +func setCtx(ctx context.Context, r *reqInfo) context.Context { + return context.WithValue(ctx, reqinfoKey, r) +} + +func getFromCtx(ctx context.Context) *reqInfo { + v := ctx.Value(reqinfoKey) + if v == nil { + return nil + } + return v.(*reqInfo) +} + +type warpSlogHandle struct { + slog.Handler +} + +func (w *warpSlogHandle) Handle(ctx context.Context, r slog.Record) error { + if w.Enabled(ctx, slog.LevelDebug) { + ri := getFromCtx(ctx) + if ri != nil { + r.AddAttrs(slog.String("ip", ri.IP), slog.String("url", ri.URL), slog.Uint64("trackID", ri.TrackId)) + } + } + return w.Handler.Handle(ctx, r) +} diff --git a/service/user.go b/service/user.go new file mode 100644 index 0000000..397e08d --- /dev/null +++ b/service/user.go @@ -0,0 +1,79 @@ +package service + +import ( + "context" + "crypto/md5" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "strings" + "time" + + "github.com/bwmarrin/snowflake" + "github.com/google/uuid" + "github.com/xmdhs/authlib-skin/config" + "github.com/xmdhs/authlib-skin/db/mysql" + "github.com/xmdhs/authlib-skin/model" + "github.com/xmdhs/authlib-skin/utils" +) + +var ErrExistUser = errors.New("用户已存在") + +func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *snowflake.Node, + c config.Config, +) error { + ou, err := q.GetUserByEmail(ctx, u.Email) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("Reg: %w", err) + } + if ou.Email != "" { + return fmt.Errorf("Reg: %w", ErrExistUser) + } + err = utils.WithTx(ctx, &sql.TxOptions{}, q, db, func(q mysql.Querier) error { + p, s := utils.Argon2ID(u.Password) + userID := snow.Generate().Int64() + _, err := q.CreateUser(ctx, mysql.CreateUserParams{ + ID: userID, + Email: u.Email, + Password: p, + Salt: s, + State: 0, + RegTime: time.Now().Unix(), + }) + if err != nil { + return err + } + var userUuid string + if c.OfflineUUID { + userUuid = uuidGen(u.Name) + } else { + userUuid = strings.ReplaceAll(uuid.New().String(), "-", "") + } + + _, err = q.CreateUserProfile(ctx, mysql.CreateUserProfileParams{ + UserID: userID, + Name: u.Name, + Uuid: userUuid, + }) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("Reg: %w", err) + } + + return nil +} + +func uuidGen(t string) string { + data := []byte("OfflinePlayer:" + t) + h := md5.New() + h.Write(data) + uuid := h.Sum(nil) + uuid[6] = (uuid[6] & 0x0f) | uint8((3&0xf)<<4) + uuid[8] = (uuid[8] & 0x3f) | 0x80 + return hex.EncodeToString(uuid) +} diff --git a/utils/argon2id.go b/utils/argon2id.go new file mode 100644 index 0000000..6e24ed0 --- /dev/null +++ b/utils/argon2id.go @@ -0,0 +1,28 @@ +package utils + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + + "golang.org/x/crypto/argon2" +) + +func Argon2ID(pass string) (password string, salt string) { + s := make([]byte, 16) + _, err := rand.Read(s) + if err != nil { + panic(err) + } + b := argon2.IDKey([]byte(pass), s, 1, 64*1024, 1, 32) + return base64.StdEncoding.EncodeToString(b), base64.StdEncoding.EncodeToString(s) +} + +func Argon2Compare(pass, hashPass string, salt []byte) bool { + b := argon2.IDKey([]byte(pass), salt, 1, 64*1024, 1, 32) + hb, err := base64.StdEncoding.DecodeString(hashPass) + if err != nil { + return false + } + return subtle.ConstantTimeCompare(b, hb) == 1 +} diff --git a/utils/decode.go b/utils/decode.go new file mode 100644 index 0000000..ac00627 --- /dev/null +++ b/utils/decode.go @@ -0,0 +1,24 @@ +package utils + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/go-playground/validator/v10" +) + +func DeCodeBody[T any](r io.Reader, v *validator.Validate) (T, error) { + jr := json.NewDecoder(r) + var a T + err := jr.Decode(&a) + if err != nil { + return a, fmt.Errorf("DeCodeBody: %w", err) + } + + err = v.Struct(a) + if err != nil { + return a, fmt.Errorf("DeCodeBody: %w", err) + } + return a, nil +} diff --git a/utils/tx.go b/utils/tx.go new file mode 100644 index 0000000..e517287 --- /dev/null +++ b/utils/tx.go @@ -0,0 +1,37 @@ +package utils + +import ( + "context" + "database/sql" + "fmt" + + "github.com/xmdhs/authlib-skin/db/mysql" +) + +func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error { + w, ok := q.(interface { + WithTx(tx *sql.Tx) *mysql.Queries + }) + var tx *sql.Tx + if ok { + fmt.Println("事务开启") // remove me + var err error + tx, err = db.BeginTx(ctx, opts) + if err != nil { + return fmt.Errorf("WithTx: %w", err) + } + defer tx.Rollback() + q = w.WithTx(tx) + } + err := f(q) + if err != nil { + return fmt.Errorf("WithTx: %w", err) + } + if tx != nil { + err := tx.Commit() + if err != nil { + return fmt.Errorf("WithTx: %w", err) + } + } + return nil +}