更改事务方式

This commit is contained in:
xmdhs 2023-09-02 22:35:53 +08:00
parent c35e71d22c
commit 280b0ad839
No known key found for this signature in database
GPG Key ID: E809D6D43DEFCC95
7 changed files with 72 additions and 49 deletions

View File

@ -1,2 +1,37 @@
//go:generate sqlc generate
package mysql
import (
"context"
"database/sql"
"fmt"
)
type QuerierWithTx interface {
Querier
Tx(ctx context.Context, f func(Querier) error) error
}
var _ QuerierWithTx = (*Queries)(nil)
func (q *Queries) Tx(ctx context.Context, f func(Querier) error) error {
db, ok := q.db.(*sql.DB)
if !ok {
return fmt.Errorf("not *sql.DB")
}
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return err
}
defer tx.Rollback()
nq := q.WithTx(tx)
err = f(nq)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}

View File

@ -1,7 +1,6 @@
package handle
import (
"database/sql"
"errors"
"log/slog"
"net/http"
@ -16,7 +15,7 @@ import (
"github.com/xmdhs/authlib-skin/utils"
)
func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) httprouter.Handle {
func Reg(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
ctx := r.Context()
@ -26,7 +25,7 @@ func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, sno
handleError(ctx, w, err.Error(), model.ErrInput, 400)
return
}
err = service.Reg(ctx, u, q, db, snow, c)
err = service.Reg(ctx, u, q, snow, c)
if err != nil {
if errors.Is(err, service.ErrExistUser) {
l.DebugContext(ctx, err.Error())

View File

@ -49,7 +49,7 @@ func ProvideDB(c config.Config) (*sql.DB, func(), error) {
return db, func() { db.Close() }, nil
}
func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.Querier, func(), error) {
func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.QuerierWithTx, func(), error) {
q, err := mysql.Prepare(ctx, db)
if err != nil {
return nil, nil, fmt.Errorf("newQuerier: %w", err)

View File

@ -1,7 +1,6 @@
package route
import (
"database/sql"
"fmt"
"log/slog"
@ -13,13 +12,13 @@ import (
"github.com/xmdhs/authlib-skin/handle"
)
func NewRoute(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) {
func NewRoute(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) {
r := httprouter.New()
err := newYggdrasil(r)
if err != nil {
return nil, fmt.Errorf("NewRoute: %w", err)
}
err = newSkinApi(r, l, q, v, db, snow, c)
err = newSkinApi(r, l, q, v, snow, c)
if err != nil {
return nil, fmt.Errorf("NewRoute: %w", err)
}
@ -31,7 +30,7 @@ func newYggdrasil(r *httprouter.Router) error {
return nil
}
func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) error {
r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, db, snow, c))
func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) error {
r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, snow, c))
return nil
}

View File

@ -26,7 +26,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
if err != nil {
return nil, nil, err
}
querier, cleanup2, err := ProvideQuerier(ctx, db)
querierWithTx, cleanup2, err := ProvideQuerier(ctx, db)
if err != nil {
cleanup()
return nil, nil, err
@ -38,7 +38,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
cleanup()
return nil, nil, err
}
router, err := route.NewRoute(logger, querier, validate, db, node, c)
router, err := route.NewRoute(logger, querierWithTx, validate, node, c)
if err != nil {
cleanup2()
cleanup()

View File

@ -23,7 +23,7 @@ var (
ErrExitsName = errors.New("用户名已存在")
)
func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *snowflake.Node,
func Reg(ctx context.Context, u model.User, q mysql.QuerierWithTx, snow *snowflake.Node,
c config.Config,
) error {
ou, err := q.GetUserByEmail(ctx, u.Email)
@ -33,9 +33,7 @@ func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *s
if ou.Email != "" {
return fmt.Errorf("Reg: %w", ErrExistUser)
}
err = utils.WithTx(ctx, &sql.TxOptions{
Isolation: sql.LevelReadCommitted,
}, q, db, func(q mysql.Querier) error {
err = q.Tx(ctx, func(q mysql.Querier) error {
p, s := utils.Argon2ID(u.Password)
userID := snow.Generate().Int64()
_, err := q.CreateUser(ctx, mysql.CreateUserParams{

View File

@ -1,36 +1,28 @@
package utils
import (
"context"
"database/sql"
"fmt"
"github.com/xmdhs/authlib-skin/db/mysql"
)
func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error {
w, ok := q.(interface {
WithTx(tx *sql.Tx) *mysql.Queries
})
var tx *sql.Tx
if ok {
var err error
tx, err = db.BeginTx(ctx, opts)
if err != nil {
return fmt.Errorf("WithTx: %w", err)
}
defer tx.Rollback()
q = w.WithTx(tx)
}
err := f(q)
if err != nil {
return fmt.Errorf("WithTx: %w", err)
}
if tx != nil {
err := tx.Commit()
if err != nil {
return fmt.Errorf("WithTx: %w", err)
}
}
return nil
}
// func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error {
// w, ok := q.(interface {
// WithTx(tx *sql.Tx) *mysql.Queries
// })
// var tx *sql.Tx
// if ok {
// var err error
// tx, err = db.BeginTx(ctx, opts)
// if err != nil {
// return fmt.Errorf("WithTx: %w", err)
// }
// defer tx.Rollback()
// q = w.WithTx(tx)
// }
// err := f(q)
// if err != nil {
// return fmt.Errorf("WithTx: %w", err)
// }
// if tx != nil {
// err := tx.Commit()
// if err != nil {
// return fmt.Errorf("WithTx: %w", err)
// }
// }
// return nil
// }