更改事务方式
This commit is contained in:
parent
c35e71d22c
commit
280b0ad839
@ -1,2 +1,37 @@
|
||||
//go:generate sqlc generate
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type QuerierWithTx interface {
|
||||
Querier
|
||||
Tx(ctx context.Context, f func(Querier) error) error
|
||||
}
|
||||
|
||||
var _ QuerierWithTx = (*Queries)(nil)
|
||||
|
||||
func (q *Queries) Tx(ctx context.Context, f func(Querier) error) error {
|
||||
db, ok := q.db.(*sql.DB)
|
||||
if !ok {
|
||||
return fmt.Errorf("not *sql.DB")
|
||||
}
|
||||
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
nq := q.WithTx(tx)
|
||||
err = f(nq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package handle
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@ -16,7 +15,7 @@ import (
|
||||
"github.com/xmdhs/authlib-skin/utils"
|
||||
)
|
||||
|
||||
func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) httprouter.Handle {
|
||||
func Reg(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) httprouter.Handle {
|
||||
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
||||
ctx := r.Context()
|
||||
|
||||
@ -26,7 +25,7 @@ func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, sno
|
||||
handleError(ctx, w, err.Error(), model.ErrInput, 400)
|
||||
return
|
||||
}
|
||||
err = service.Reg(ctx, u, q, db, snow, c)
|
||||
err = service.Reg(ctx, u, q, snow, c)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrExistUser) {
|
||||
l.DebugContext(ctx, err.Error())
|
||||
|
@ -49,7 +49,7 @@ func ProvideDB(c config.Config) (*sql.DB, func(), error) {
|
||||
return db, func() { db.Close() }, nil
|
||||
}
|
||||
|
||||
func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.Querier, func(), error) {
|
||||
func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.QuerierWithTx, func(), error) {
|
||||
q, err := mysql.Prepare(ctx, db)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("newQuerier: %w", err)
|
||||
|
@ -1,7 +1,6 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
@ -13,13 +12,13 @@ import (
|
||||
"github.com/xmdhs/authlib-skin/handle"
|
||||
)
|
||||
|
||||
func NewRoute(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) {
|
||||
func NewRoute(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) {
|
||||
r := httprouter.New()
|
||||
err := newYggdrasil(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewRoute: %w", err)
|
||||
}
|
||||
err = newSkinApi(r, l, q, v, db, snow, c)
|
||||
err = newSkinApi(r, l, q, v, snow, c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewRoute: %w", err)
|
||||
}
|
||||
@ -31,7 +30,7 @@ func newYggdrasil(r *httprouter.Router) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) error {
|
||||
r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, db, snow, c))
|
||||
func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) error {
|
||||
r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, snow, c))
|
||||
return nil
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
querier, cleanup2, err := ProvideQuerier(ctx, db)
|
||||
querierWithTx, cleanup2, err := ProvideQuerier(ctx, db)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, nil, err
|
||||
@ -38,7 +38,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
|
||||
cleanup()
|
||||
return nil, nil, err
|
||||
}
|
||||
router, err := route.NewRoute(logger, querier, validate, db, node, c)
|
||||
router, err := route.NewRoute(logger, querierWithTx, validate, node, c)
|
||||
if err != nil {
|
||||
cleanup2()
|
||||
cleanup()
|
||||
|
@ -23,7 +23,7 @@ var (
|
||||
ErrExitsName = errors.New("用户名已存在")
|
||||
)
|
||||
|
||||
func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *snowflake.Node,
|
||||
func Reg(ctx context.Context, u model.User, q mysql.QuerierWithTx, snow *snowflake.Node,
|
||||
c config.Config,
|
||||
) error {
|
||||
ou, err := q.GetUserByEmail(ctx, u.Email)
|
||||
@ -33,9 +33,7 @@ func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *s
|
||||
if ou.Email != "" {
|
||||
return fmt.Errorf("Reg: %w", ErrExistUser)
|
||||
}
|
||||
err = utils.WithTx(ctx, &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
}, q, db, func(q mysql.Querier) error {
|
||||
err = q.Tx(ctx, func(q mysql.Querier) error {
|
||||
p, s := utils.Argon2ID(u.Password)
|
||||
userID := snow.Generate().Int64()
|
||||
_, err := q.CreateUser(ctx, mysql.CreateUserParams{
|
||||
|
60
utils/tx.go
60
utils/tx.go
@ -1,36 +1,28 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/xmdhs/authlib-skin/db/mysql"
|
||||
)
|
||||
|
||||
func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error {
|
||||
w, ok := q.(interface {
|
||||
WithTx(tx *sql.Tx) *mysql.Queries
|
||||
})
|
||||
var tx *sql.Tx
|
||||
if ok {
|
||||
var err error
|
||||
tx, err = db.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("WithTx: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
q = w.WithTx(tx)
|
||||
}
|
||||
err := f(q)
|
||||
if err != nil {
|
||||
return fmt.Errorf("WithTx: %w", err)
|
||||
}
|
||||
if tx != nil {
|
||||
err := tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("WithTx: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error {
|
||||
// w, ok := q.(interface {
|
||||
// WithTx(tx *sql.Tx) *mysql.Queries
|
||||
// })
|
||||
// var tx *sql.Tx
|
||||
// if ok {
|
||||
// var err error
|
||||
// tx, err = db.BeginTx(ctx, opts)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("WithTx: %w", err)
|
||||
// }
|
||||
// defer tx.Rollback()
|
||||
// q = w.WithTx(tx)
|
||||
// }
|
||||
// err := f(q)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("WithTx: %w", err)
|
||||
// }
|
||||
// if tx != nil {
|
||||
// err := tx.Commit()
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("WithTx: %w", err)
|
||||
// }
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
Loading…
x
Reference in New Issue
Block a user