diff --git a/db/mysql/sqlc.go b/db/mysql/sqlc.go index a64604b..b40f15c 100644 --- a/db/mysql/sqlc.go +++ b/db/mysql/sqlc.go @@ -1,2 +1,37 @@ //go:generate sqlc generate package mysql + +import ( + "context" + "database/sql" + "fmt" +) + +type QuerierWithTx interface { + Querier + Tx(ctx context.Context, f func(Querier) error) error +} + +var _ QuerierWithTx = (*Queries)(nil) + +func (q *Queries) Tx(ctx context.Context, f func(Querier) error) error { + db, ok := q.db.(*sql.DB) + if !ok { + return fmt.Errorf("not *sql.DB") + } + tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + if err != nil { + return err + } + defer tx.Rollback() + nq := q.WithTx(tx) + err = f(nq) + if err != nil { + return err + } + err = tx.Commit() + if err != nil { + return err + } + return nil +} diff --git a/handle/user.go b/handle/user.go index 727c7f5..c51bf93 100644 --- a/handle/user.go +++ b/handle/user.go @@ -1,7 +1,6 @@ package handle import ( - "database/sql" "errors" "log/slog" "net/http" @@ -16,7 +15,7 @@ import ( "github.com/xmdhs/authlib-skin/utils" ) -func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) httprouter.Handle { +func Reg(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { ctx := r.Context() @@ -26,7 +25,7 @@ func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, sno handleError(ctx, w, err.Error(), model.ErrInput, 400) return } - err = service.Reg(ctx, u, q, db, snow, c) + err = service.Reg(ctx, u, q, snow, c) if err != nil { if errors.Is(err, service.ErrExistUser) { l.DebugContext(ctx, err.Error()) diff --git a/server/provide.go b/server/provide.go index 81986f0..d77a1c4 100644 --- a/server/provide.go +++ b/server/provide.go @@ -49,7 +49,7 @@ func ProvideDB(c config.Config) (*sql.DB, func(), error) { return db, func() { db.Close() }, nil } -func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.Querier, func(), error) { +func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.QuerierWithTx, func(), error) { q, err := mysql.Prepare(ctx, db) if err != nil { return nil, nil, fmt.Errorf("newQuerier: %w", err) diff --git a/server/route/route.go b/server/route/route.go index 182c62b..e9f17f8 100644 --- a/server/route/route.go +++ b/server/route/route.go @@ -1,7 +1,6 @@ package route import ( - "database/sql" "fmt" "log/slog" @@ -13,13 +12,13 @@ import ( "github.com/xmdhs/authlib-skin/handle" ) -func NewRoute(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) { +func NewRoute(l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) { r := httprouter.New() err := newYggdrasil(r) if err != nil { return nil, fmt.Errorf("NewRoute: %w", err) } - err = newSkinApi(r, l, q, v, db, snow, c) + err = newSkinApi(r, l, q, v, snow, c) if err != nil { return nil, fmt.Errorf("NewRoute: %w", err) } @@ -31,7 +30,7 @@ func newYggdrasil(r *httprouter.Router) error { return nil } -func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) error { - r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, db, snow, c)) +func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.QuerierWithTx, v *validator.Validate, snow *snowflake.Node, c config.Config) error { + r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, snow, c)) return nil } diff --git a/server/wire_gen.go b/server/wire_gen.go index e9d911b..81b4c51 100644 --- a/server/wire_gen.go +++ b/server/wire_gen.go @@ -26,7 +26,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func() if err != nil { return nil, nil, err } - querier, cleanup2, err := ProvideQuerier(ctx, db) + querierWithTx, cleanup2, err := ProvideQuerier(ctx, db) if err != nil { cleanup() return nil, nil, err @@ -38,7 +38,7 @@ func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func() cleanup() return nil, nil, err } - router, err := route.NewRoute(logger, querier, validate, db, node, c) + router, err := route.NewRoute(logger, querierWithTx, validate, node, c) if err != nil { cleanup2() cleanup() diff --git a/service/user.go b/service/user.go index 82659a6..be99d73 100644 --- a/service/user.go +++ b/service/user.go @@ -23,7 +23,7 @@ var ( ErrExitsName = errors.New("用户名已存在") ) -func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *snowflake.Node, +func Reg(ctx context.Context, u model.User, q mysql.QuerierWithTx, snow *snowflake.Node, c config.Config, ) error { ou, err := q.GetUserByEmail(ctx, u.Email) @@ -33,9 +33,7 @@ func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *s if ou.Email != "" { return fmt.Errorf("Reg: %w", ErrExistUser) } - err = utils.WithTx(ctx, &sql.TxOptions{ - Isolation: sql.LevelReadCommitted, - }, q, db, func(q mysql.Querier) error { + err = q.Tx(ctx, func(q mysql.Querier) error { p, s := utils.Argon2ID(u.Password) userID := snow.Generate().Int64() _, err := q.CreateUser(ctx, mysql.CreateUserParams{ diff --git a/utils/tx.go b/utils/tx.go index a3f1b64..9359906 100644 --- a/utils/tx.go +++ b/utils/tx.go @@ -1,36 +1,28 @@ package utils -import ( - "context" - "database/sql" - "fmt" - - "github.com/xmdhs/authlib-skin/db/mysql" -) - -func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error { - w, ok := q.(interface { - WithTx(tx *sql.Tx) *mysql.Queries - }) - var tx *sql.Tx - if ok { - var err error - tx, err = db.BeginTx(ctx, opts) - if err != nil { - return fmt.Errorf("WithTx: %w", err) - } - defer tx.Rollback() - q = w.WithTx(tx) - } - err := f(q) - if err != nil { - return fmt.Errorf("WithTx: %w", err) - } - if tx != nil { - err := tx.Commit() - if err != nil { - return fmt.Errorf("WithTx: %w", err) - } - } - return nil -} +// func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.DB, f func(mysql.Querier) error) error { +// w, ok := q.(interface { +// WithTx(tx *sql.Tx) *mysql.Queries +// }) +// var tx *sql.Tx +// if ok { +// var err error +// tx, err = db.BeginTx(ctx, opts) +// if err != nil { +// return fmt.Errorf("WithTx: %w", err) +// } +// defer tx.Rollback() +// q = w.WithTx(tx) +// } +// err := f(q) +// if err != nil { +// return fmt.Errorf("WithTx: %w", err) +// } +// if tx != nil { +// err := tx.Commit() +// if err != nil { +// return fmt.Errorf("WithTx: %w", err) +// } +// } +// return nil +// }