更改事务方式

This commit is contained in:
xmdhs 2023-09-02 22:35:53 +08:00
parent c35e71d22c
commit 280b0ad839
No known key found for this signature in database
GPG Key ID: E809D6D43DEFCC95
7 changed files with 72 additions and 49 deletions

View File

@ -1,2 +1,37 @@
//go:generate sqlc generate //go:generate sqlc generate
package mysql package mysql
import (
"context"
"database/sql"
"fmt"
)
type QuerierWithTx interface {
Querier
Tx(ctx context.Context, f func(Querier) error) error
}
var _ QuerierWithTx = (*Queries)(nil)
func (q *Queries) Tx(ctx context.Context, f func(Querier) error) error {
db, ok := q.db.(*sql.DB)
if !ok {
return fmt.Errorf("not *sql.DB")
}
tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return err
}
defer tx.Rollback()
nq := q.WithTx(tx)
err = f(nq)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
return nil
}

View File

@ -1,7 +1,6 @@
package handle package handle
import ( import (
"database/sql"
"errors" "errors"
"log/slog" "log/slog"
"net/http" "net/http"
@ -16,7 +15,7 @@ import (
"github.com/xmdhs/authlib-skin/utils" "github.com/xmdhs/authlib-skin/utils"
) )
func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) httprouter.Handle { func Reg(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
ctx := r.Context() ctx := r.Context()
@ -26,7 +25,7 @@ func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, sno
handleError(ctx, w, err.Error(), model.ErrInput, 400) handleError(ctx, w, err.Error(), model.ErrInput, 400)
return return
} }
err = service.Reg(ctx, u, q, db, snow, c) err = service.Reg(ctx, u, q, snow, c)
if err != nil { if err != nil {
if errors.Is(err, service.ErrExistUser) { if errors.Is(err, service.ErrExistUser) {
l.DebugContext(ctx, err.Error()) l.DebugContext(ctx, err.Error())

View File

@ -49,7 +49,7 @@ func ProvideDB(c config.Config) (*sql.DB, func(), error) {
return db, func() { db.Close() }, nil return db, func() { db.Close() }, nil
} }
func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.Querier, func(), error) { func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.QuerierWithTx, func(), error) {
q, err := mysql.Prepare(ctx, db) q, err := mysql.Prepare(ctx, db)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("newQuerier: %w", err) return nil, nil, fmt.Errorf("newQuerier: %w", err)

View File

@ -1,7 +1,6 @@
package route package route
import ( import (
"database/sql"
"fmt" "fmt"
"log/slog" "log/slog"
@ -13,13 +12,13 @@ import (
"github.com/xmdhs/authlib-skin/handle" "github.com/xmdhs/authlib-skin/handle"
) )
func NewRoute(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) { func NewRoute(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) {
r := httprouter.New() r := httprouter.New()
err := newYggdrasil(r) err := newYggdrasil(r)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewRoute: %w", err) return nil, fmt.Errorf("NewRoute: %w", err)
} }
err = newSkinApi(r, l, q, v, db, snow, c) err = newSkinApi(r, l, q, v, snow, c)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewRoute: %w", err) return nil, fmt.Errorf("NewRoute: %w", err)
} }
@ -31,7 +30,7 @@ func newYggdrasil(r *httprouter.Router) error {
return nil return nil
} }
func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) error { func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) error {
r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, db, snow, c)) r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, snow, c))
return nil return nil
} }

View File

@ -26,7 +26,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
querier, cleanup2, err := ProvideQuerier(ctx, db) querierWithTx, cleanup2, err := ProvideQuerier(ctx, db)
if err != nil { if err != nil {
cleanup() cleanup()
return nil, nil, err return nil, nil, err
@ -38,7 +38,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func()
cleanup() cleanup()
return nil, nil, err return nil, nil, err
} }
router, err := route.NewRoute(logger, querier, validate, db, node, c) router, err := route.NewRoute(logger, querierWithTx, validate, node, c)
if err != nil { if err != nil {
cleanup2() cleanup2()
cleanup() cleanup()

View File

@ -23,7 +23,7 @@ var (
ErrExitsName = errors.New("用户名已存在") ErrExitsName = errors.New("用户名已存在")
) )
func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *snowflake.Node, func Reg(ctx context.Context, u model.User, q mysql.QuerierWithTx, snow *snowflake.Node,
c config.Config, c config.Config,
) error { ) error {
ou, err := q.GetUserByEmail(ctx, u.Email) ou, err := q.GetUserByEmail(ctx, u.Email)
@ -33,9 +33,7 @@ func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *s
if ou.Email != "" { if ou.Email != "" {
return fmt.Errorf("Reg: %w", ErrExistUser) return fmt.Errorf("Reg: %w", ErrExistUser)
} }
err = utils.WithTx(ctx, &sql.TxOptions{ err = q.Tx(ctx, func(q mysql.Querier) error {
Isolation: sql.LevelReadCommitted,
}, q, db, func(q mysql.Querier) error {
p, s := utils.Argon2ID(u.Password) p, s := utils.Argon2ID(u.Password)
userID := snow.Generate().Int64() userID := snow.Generate().Int64()
_, err := q.CreateUser(ctx, mysql.CreateUserParams{ _, err := q.CreateUser(ctx, mysql.CreateUserParams{

View File

@ -1,36 +1,28 @@
package utils package utils
import ( // func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error {
"context" // w, ok := q.(interface {
"database/sql" // WithTx(tx *sql.Tx) *mysql.Queries
"fmt" // })
// var tx *sql.Tx
"github.com/xmdhs/authlib-skin/db/mysql" // if ok {
) // var err error
// tx, err = db.BeginTx(ctx, opts)
func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error { // if err != nil {
w, ok := q.(interface { // return fmt.Errorf("WithTx: %w", err)
WithTx(tx *sql.Tx) *mysql.Queries // }
}) // defer tx.Rollback()
var tx *sql.Tx // q = w.WithTx(tx)
if ok { // }
var err error // err := f(q)
tx, err = db.BeginTx(ctx, opts) // if err != nil {
if err != nil { // return fmt.Errorf("WithTx: %w", err)
return fmt.Errorf("WithTx: %w", err) // }
} // if tx != nil {
defer tx.Rollback() // err := tx.Commit()
q = w.WithTx(tx) // if err != nil {
} // return fmt.Errorf("WithTx: %w", err)
err := f(q) // }
if err != nil { // }
return fmt.Errorf("WithTx: %w", err) // return nil
} // }
if tx != nil {
err := tx.Commit()
if err != nil {
return fmt.Errorf("WithTx: %w", err)
}
}
return nil
}