更改事务方式
This commit is contained in:
parent
c35e71d22c
commit
280b0ad839
@ -1,2 +1,37 @@
|
|||||||
//go:generate sqlc generate
|
//go:generate sqlc generate
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type QuerierWithTx interface {
|
||||||
|
Querier
|
||||||
|
Tx(ctx context.Context, f func(Querier) error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ QuerierWithTx = (*Queries)(nil)
|
||||||
|
|
||||||
|
func (q *Queries) Tx(ctx context.Context, f func(Querier) error) error {
|
||||||
|
db, ok := q.db.(*sql.DB)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("not *sql.DB")
|
||||||
|
}
|
||||||
|
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
nq := q.WithTx(tx)
|
||||||
|
err = f(nq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package handle
|
package handle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -16,7 +15,7 @@ import (
|
|||||||
"github.com/xmdhs/authlib-skin/utils"
|
"github.com/xmdhs/authlib-skin/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) httprouter.Handle {
|
func Reg(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) httprouter.Handle {
|
||||||
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
@ -26,7 +25,7 @@ func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, sno
|
|||||||
handleError(ctx, w, err.Error(), model.ErrInput, 400)
|
handleError(ctx, w, err.Error(), model.ErrInput, 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = service.Reg(ctx, u, q, db, snow, c)
|
err = service.Reg(ctx, u, q, snow, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrExistUser) {
|
if errors.Is(err, service.ErrExistUser) {
|
||||||
l.DebugContext(ctx, err.Error())
|
l.DebugContext(ctx, err.Error())
|
||||||
|
@ -49,7 +49,7 @@ func ProvideDB(c config.Config) (*sql.DB, func(), error) {
|
|||||||
return db, func() { db.Close() }, nil
|
return db, func() { db.Close() }, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.Querier, func(), error) {
|
func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.QuerierWithTx, func(), error) {
|
||||||
q, err := mysql.Prepare(ctx, db)
|
q, err := mysql.Prepare(ctx, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("newQuerier: %w", err)
|
return nil, nil, fmt.Errorf("newQuerier: %w", err)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package route
|
package route
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
@ -13,13 +12,13 @@ import (
|
|||||||
"github.com/xmdhs/authlib-skin/handle"
|
"github.com/xmdhs/authlib-skin/handle"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewRoute(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) {
|
func NewRoute(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) {
|
||||||
r := httprouter.New()
|
r := httprouter.New()
|
||||||
err := newYggdrasil(r)
|
err := newYggdrasil(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewRoute: %w", err)
|
return nil, fmt.Errorf("NewRoute: %w", err)
|
||||||
}
|
}
|
||||||
err = newSkinApi(r, l, q, v, db, snow, c)
|
err = newSkinApi(r, l, q, v, snow, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewRoute: %w", err)
|
return nil, fmt.Errorf("NewRoute: %w", err)
|
||||||
}
|
}
|
||||||
@ -31,7 +30,7 @@ func newYggdrasil(r *httprouter.Router) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) error {
|
func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) error {
|
||||||
r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, db, snow, c))
|
r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, snow, c))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
querier, cleanup2, err := ProvideQuerier(ctx, db)
|
querierWithTx, cleanup2, err := ProvideQuerier(ctx, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanup()
|
cleanup()
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@ -38,7 +38,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
|
|||||||
cleanup()
|
cleanup()
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
router, err := route.NewRoute(logger, querier, validate, db, node, c)
|
router, err := route.NewRoute(logger, querierWithTx, validate, node, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanup2()
|
cleanup2()
|
||||||
cleanup()
|
cleanup()
|
||||||
|
@ -23,7 +23,7 @@ var (
|
|||||||
ErrExitsName = errors.New("用户名已存在")
|
ErrExitsName = errors.New("用户名已存在")
|
||||||
)
|
)
|
||||||
|
|
||||||
func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *snowflake.Node,
|
func Reg(ctx context.Context, u model.User, q mysql.QuerierWithTx, snow *snowflake.Node,
|
||||||
c config.Config,
|
c config.Config,
|
||||||
) error {
|
) error {
|
||||||
ou, err := q.GetUserByEmail(ctx, u.Email)
|
ou, err := q.GetUserByEmail(ctx, u.Email)
|
||||||
@ -33,9 +33,7 @@ func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *s
|
|||||||
if ou.Email != "" {
|
if ou.Email != "" {
|
||||||
return fmt.Errorf("Reg: %w", ErrExistUser)
|
return fmt.Errorf("Reg: %w", ErrExistUser)
|
||||||
}
|
}
|
||||||
err = utils.WithTx(ctx, &sql.TxOptions{
|
err = q.Tx(ctx, func(q mysql.Querier) error {
|
||||||
Isolation: sql.LevelReadCommitted,
|
|
||||||
}, q, db, func(q mysql.Querier) error {
|
|
||||||
p, s := utils.Argon2ID(u.Password)
|
p, s := utils.Argon2ID(u.Password)
|
||||||
userID := snow.Generate().Int64()
|
userID := snow.Generate().Int64()
|
||||||
_, err := q.CreateUser(ctx, mysql.CreateUserParams{
|
_, err := q.CreateUser(ctx, mysql.CreateUserParams{
|
||||||
|
60
utils/tx.go
60
utils/tx.go
@ -1,36 +1,28 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
// func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error {
|
||||||
"context"
|
// w, ok := q.(interface {
|
||||||
"database/sql"
|
// WithTx(tx *sql.Tx) *mysql.Queries
|
||||||
"fmt"
|
// })
|
||||||
|
// var tx *sql.Tx
|
||||||
"github.com/xmdhs/authlib-skin/db/mysql"
|
// if ok {
|
||||||
)
|
// var err error
|
||||||
|
// tx, err = db.BeginTx(ctx, opts)
|
||||||
func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error {
|
// if err != nil {
|
||||||
w, ok := q.(interface {
|
// return fmt.Errorf("WithTx: %w", err)
|
||||||
WithTx(tx *sql.Tx) *mysql.Queries
|
// }
|
||||||
})
|
// defer tx.Rollback()
|
||||||
var tx *sql.Tx
|
// q = w.WithTx(tx)
|
||||||
if ok {
|
// }
|
||||||
var err error
|
// err := f(q)
|
||||||
tx, err = db.BeginTx(ctx, opts)
|
// if err != nil {
|
||||||
if err != nil {
|
// return fmt.Errorf("WithTx: %w", err)
|
||||||
return fmt.Errorf("WithTx: %w", err)
|
// }
|
||||||
}
|
// if tx != nil {
|
||||||
defer tx.Rollback()
|
// err := tx.Commit()
|
||||||
q = w.WithTx(tx)
|
// if err != nil {
|
||||||
}
|
// return fmt.Errorf("WithTx: %w", err)
|
||||||
err := f(q)
|
// }
|
||||||
if err != nil {
|
// }
|
||||||
return fmt.Errorf("WithTx: %w", err)
|
// return nil
|
||||||
}
|
// }
|
||||||
if tx != nil {
|
|
||||||
err := tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("WithTx: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user