diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..42f5e3e --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "gopls": { + "buildFlags":["-tags=wireinject"] + }, +} \ No newline at end of file diff --git a/cmd/testserver/main.go b/cmd/testserver/main.go new file mode 100644 index 0000000..330ba8a --- /dev/null +++ b/cmd/testserver/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "context" + "fmt" + + "github.com/xmdhs/authlib-skin/config" + "github.com/xmdhs/authlib-skin/server" +) + +func main() { + ctx := context.Background() + config := config.Config{ + OfflineUUID: true, + Port: "127.0.0.1:8080", + Log: struct { + Level string + Json bool + }{ + Level: "debug", + }, + Sql: struct{ MysqlDsn string }{ + MysqlDsn: "sizzle1445:jnjjJQ8^YF&8PN@tcp(192.168.20.1)/skin", + }, + Node: 0, + Epoch: 1693645718534, + } + s, c, err := server.InitializeRoute(ctx, config) + if err != nil { + panic(err) + } + defer c() + fmt.Println(s.ListenAndServe()) +} diff --git a/config/config.go b/config/config.go index 976e68c..b635396 100644 --- a/config/config.go +++ b/config/config.go @@ -2,4 +2,14 @@ package config type Config struct { OfflineUUID bool + Port string + Log struct { + Level string + Json bool + } + Sql struct { + MysqlDsn string + } + Node int64 + Epoch int64 } diff --git a/db/mysql/db.go b/db/mysql/db.go index 48af8b0..a92ee49 100644 --- a/db/mysql/db.go +++ b/db/mysql/db.go @@ -37,6 +37,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.getUserByEmailStmt, err = db.PrepareContext(ctx, getUserByEmail); err != nil { return nil, fmt.Errorf("error preparing query GetUserByEmail: %w", err) } + if q.getUserProfileByNameStmt, err = db.PrepareContext(ctx, getUserProfileByName); err != nil { + return nil, fmt.Errorf("error preparing query GetUserProfileByName: %w", err) + } if q.listUserStmt, err = db.PrepareContext(ctx, listUser); err != nil { return nil, fmt.Errorf("error preparing query ListUser: %w", err) } @@ -70,6 +73,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getUserByEmailStmt: %w", cerr) } } + if q.getUserProfileByNameStmt != nil { + if cerr := q.getUserProfileByNameStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getUserProfileByNameStmt: %w", cerr) + } + } if q.listUserStmt != nil { if cerr := q.listUserStmt.Close(); cerr != nil { err = fmt.Errorf("error closing listUserStmt: %w", cerr) @@ -112,25 +120,27 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar } type Queries struct { - db DBTX - tx *sql.Tx - createUserStmt *sql.Stmt - createUserProfileStmt *sql.Stmt - deleteUserStmt *sql.Stmt - getUserStmt *sql.Stmt - getUserByEmailStmt *sql.Stmt - listUserStmt *sql.Stmt + db DBTX + tx *sql.Tx + createUserStmt *sql.Stmt + createUserProfileStmt *sql.Stmt + deleteUserStmt *sql.Stmt + getUserStmt *sql.Stmt + getUserByEmailStmt *sql.Stmt + getUserProfileByNameStmt *sql.Stmt + listUserStmt *sql.Stmt } func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ - db: tx, - tx: tx, - createUserStmt: q.createUserStmt, - createUserProfileStmt: q.createUserProfileStmt, - deleteUserStmt: q.deleteUserStmt, - getUserStmt: q.getUserStmt, - getUserByEmailStmt: q.getUserByEmailStmt, - listUserStmt: q.listUserStmt, + db: tx, + tx: tx, + createUserStmt: q.createUserStmt, + createUserProfileStmt: q.createUserProfileStmt, + deleteUserStmt: q.deleteUserStmt, + getUserStmt: q.getUserStmt, + getUserByEmailStmt: q.getUserByEmailStmt, + getUserProfileByNameStmt: q.getUserProfileByNameStmt, + listUserStmt: q.listUserStmt, } } diff --git a/db/mysql/querier.go b/db/mysql/querier.go index 157f94a..eddf674 100644 --- a/db/mysql/querier.go +++ b/db/mysql/querier.go @@ -13,6 +13,7 @@ type Querier interface { DeleteUser(ctx context.Context, id int64) error GetUser(ctx context.Context, id int64) (User, error) GetUserByEmail(ctx context.Context, email string) (User, error) + GetUserProfileByName(ctx context.Context, name string) (UserProfile, error) ListUser(ctx context.Context) ([]User, error) } diff --git a/db/mysql/query.sql.go b/db/mysql/query.sql.go index 5c2fbcd..75a03bd 100644 --- a/db/mysql/query.sql.go +++ b/db/mysql/query.sql.go @@ -9,16 +9,7 @@ import ( ) const createUser = `-- name: CreateUser :execresult -REPLACE INTO user ( - id, - email, - password, - salt, - state, - reg_time -) -VALUES - (?, ?, ?, ?, ?, ?) + REPLACE INTO user ( id, email, password, salt, state, reg_time ) VALUES (?, ?, ?, ?, ?, ?) ` type CreateUserParams struct { @@ -42,9 +33,7 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (sql.Res } const createUserProfile = `-- name: CreateUserProfile :execresult -REPLACE INTO ` + "`" + `user_profile` + "`" + ` (` + "`" + `user_id` + "`" + `, ` + "`" + `name` + "`" + `, ` + "`" + `uuid` + "`" + `) -VALUES - (?, ?, ?) + REPLACE INTO ` + "`" + `user_profile` + "`" + ` (` + "`" + `user_id` + "`" + `, ` + "`" + `name` + "`" + `, ` + "`" + `uuid` + "`" + `) VALUES (?, ?, ?) ` type CreateUserProfileParams struct { @@ -58,10 +47,9 @@ func (q *Queries) CreateUserProfile(ctx context.Context, arg CreateUserProfilePa } const deleteUser = `-- name: DeleteUser :exec -DELETE FROM - user -WHERE - id = ? + DELETE +FROM user +WHERE id = ? ` func (q *Queries) DeleteUser(ctx context.Context, id int64) error { @@ -70,14 +58,10 @@ func (q *Queries) DeleteUser(ctx context.Context, id int64) error { } const getUser = `-- name: GetUser :one -SELECT - id, email, password, salt, state, reg_time -FROM - user -WHERE - id = ? -LIMIT - 1 +SELECT id, email, password, salt, state, reg_time +FROM user +WHERE id = ? +LIMIT 1 ` func (q *Queries) GetUser(ctx context.Context, id int64) (User, error) { @@ -95,14 +79,10 @@ func (q *Queries) GetUser(ctx context.Context, id int64) (User, error) { } const getUserByEmail = `-- name: GetUserByEmail :one -SELECT - id, email, password, salt, state, reg_time -FROM - user -WHERE - email = ? -LIMIT - 1 +SELECT id, email, password, salt, state, reg_time +FROM user +WHERE email = ? +LIMIT 1 ` func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) { @@ -119,13 +99,24 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error return i, err } +const getUserProfileByName = `-- name: GetUserProfileByName :one +SELECT user_id, name, uuid +FROM ` + "`" + `user_profile` + "`" + ` +WHERE ` + "`" + `name` + "`" + ` = ? +LIMIT 1 +` + +func (q *Queries) GetUserProfileByName(ctx context.Context, name string) (UserProfile, error) { + row := q.queryRow(ctx, q.getUserProfileByNameStmt, getUserProfileByName, name) + var i UserProfile + err := row.Scan(&i.UserID, &i.Name, &i.Uuid) + return i, err +} + const listUser = `-- name: ListUser :many -SELECT - id, email, password, salt, state, reg_time -FROM - user -ORDER BY - reg_time +SELECT id, email, password, salt, state, reg_time +FROM user +ORDER BY reg_time ` func (q *Queries) ListUser(ctx context.Context) ([]User, error) { diff --git a/db/mysql/sql/query.sql b/db/mysql/sql/query.sql index fa3872c..790a965 100644 --- a/db/mysql/sql/query.sql +++ b/db/mysql/sql/query.sql @@ -1,51 +1,27 @@ -- name: GetUser :one -SELECT - * -FROM - user -WHERE - id = ? -LIMIT - 1; - --- name: ListUser :many -SELECT - * -FROM - user -ORDER BY - reg_time; - --- name: CreateUser :execresult -REPLACE INTO user ( - id, - email, - password, - salt, - state, - reg_time -) -VALUES - (?, ?, ?, ?, ?, ?); - --- name: DeleteUser :exec -DELETE FROM - user -WHERE - id = ?; - --- name: CreateUserProfile :execresult -REPLACE INTO `user_profile` (`user_id`, `name`, `uuid`) -VALUES - (?, ?, ?); - - --- name: GetUserByEmail :one -SELECT - * -FROM - user -WHERE - email = ? -LIMIT - 1; \ No newline at end of file +SELECT * +FROM user +WHERE id = ? +LIMIT 1; +-- name: ListUser :many +SELECT * +FROM user +ORDER BY reg_time; +-- name: CreateUser :execresult + REPLACE INTO user ( id, email, password, salt, state, reg_time ) VALUES (?, ?, ?, ?, ?, ?); +-- name: DeleteUser :exec + DELETE +FROM user +WHERE id = ?; +-- name: CreateUserProfile :execresult + REPLACE INTO `user_profile` (`user_id`, `name`, `uuid`) VALUES (?, ?, ?); +-- name: GetUserByEmail :one +SELECT * +FROM user +WHERE email = ? +LIMIT 1 for update; +-- name: GetUserProfileByName :one +SELECT * +FROM `user_profile` +WHERE `name` = ? +LIMIT 1 for update; \ No newline at end of file diff --git a/db/mysql/sql/schema.sql b/db/mysql/sql/schema.sql index b7929cb..b09c6ae 100644 --- a/db/mysql/sql/schema.sql +++ b/db/mysql/sql/schema.sql @@ -31,4 +31,6 @@ CREATE TABLE IF NOT EXISTS `user_profile` ( user_id BIGINT PRIMARY KEY, name VARCHAR(20) NOT NULL, uuid text NOT NULL -); \ No newline at end of file +); + +CREATE UNIQUE INDEX IF NOT EXISTS name_index ON user_profile (name); \ No newline at end of file diff --git a/go.mod b/go.mod index f66c92c..261288c 100644 --- a/go.mod +++ b/go.mod @@ -3,18 +3,20 @@ module github.com/xmdhs/authlib-skin go 1.21.0 require ( + github.com/bwmarrin/snowflake v0.3.0 github.com/go-playground/validator/v10 v10.15.3 + github.com/go-sql-driver/mysql v1.7.1 + github.com/google/uuid v1.3.1 + github.com/google/wire v0.5.0 github.com/julienschmidt/httprouter v1.3.0 + golang.org/x/crypto v0.7.0 ) require ( - github.com/bwmarrin/snowflake v0.3.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/google/uuid v1.3.1 // indirect github.com/leodido/go-urn v1.2.4 // indirect - golang.org/x/crypto v0.7.0 // indirect golang.org/x/net v0.8.0 // indirect golang.org/x/sys v0.6.0 // indirect golang.org/x/text v0.8.0 // indirect diff --git a/go.sum b/go.sum index 6eaad15..2e6e3df 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,14 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.15.3 h1:S+sSpunYjNPDuXkWbK+x+bA7iXiW296KG4dL3X7xUZo= github.com/go-playground/validator/v10 v10.15.3/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/subcommands v1.0.1/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/wire v0.5.0 h1:I7ELFeVBr3yfPIcc8+MWvrjk+3VjbcSzoXm3JVa+jD8= +github.com/google/wire v0.5.0/go.mod h1:ngWDr9Qvq3yZA10YrxfyGELY/AFWGVpy9c1LTRi1EoU= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= @@ -28,14 +34,19 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/tools v0.0.0-20190422233926-fe54fb35175b/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/handle/user.go b/handle/user.go index 0e95eb1..727c7f5 100644 --- a/handle/user.go +++ b/handle/user.go @@ -2,6 +2,7 @@ package handle import ( "database/sql" + "errors" "log/slog" "net/http" @@ -27,6 +28,11 @@ func Reg(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, sno } err = service.Reg(ctx, u, q, db, snow, c) if err != nil { + if errors.Is(err, service.ErrExistUser) { + l.DebugContext(ctx, err.Error()) + handleError(ctx, w, err.Error(), model.ErrExistUser, 400) + return + } l.WarnContext(ctx, err.Error()) handleError(ctx, w, err.Error(), model.ErrService, 500) return diff --git a/model/const.go b/model/const.go index e0c78ea..7f5cb58 100644 --- a/model/const.go +++ b/model/const.go @@ -6,4 +6,5 @@ const ( OK APIStatus = iota ErrInput ErrService + ErrExistUser ) diff --git a/model/model.go b/model/model.go index 64b209c..452d961 100644 --- a/model/model.go +++ b/model/model.go @@ -8,6 +8,6 @@ type API[T any] struct { type User struct { Email string `validate:"required,email"` - Password string `validate:"required,sha256"` + Password string `validate:"required,min=6,max=50"` Name string `validate:"required,min=3,max=16"` } diff --git a/server/provide.go b/server/provide.go new file mode 100644 index 0000000..81986f0 --- /dev/null +++ b/server/provide.go @@ -0,0 +1,73 @@ +package server + +import ( + "context" + "database/sql" + "fmt" + "log/slog" + "os" + + "github.com/bwmarrin/snowflake" + "github.com/go-playground/validator/v10" + _ "github.com/go-sql-driver/mysql" + "github.com/google/wire" + "github.com/xmdhs/authlib-skin/config" + "github.com/xmdhs/authlib-skin/db/mysql" +) + +func ProvideSlog(c config.Config) slog.Handler { + var level slog.Level + switch c.Log.Level { + case "debug": + level = slog.LevelDebug + case "info": + level = slog.LevelInfo + case "warn": + level = slog.LevelWarn + case "error": + level = slog.LevelError + } + o := &slog.HandlerOptions{Level: level} + + var h slog.Handler + if c.Log.Json { + h = slog.NewJSONHandler(os.Stderr, o) + } else { + h = slog.NewTextHandler(os.Stderr, o) + } + + return h +} + +func ProvideDB(c config.Config) (*sql.DB, func(), error) { + db, err := sql.Open("mysql", c.Sql.MysqlDsn) + if err != nil { + return nil, nil, fmt.Errorf("newDB: %w", err) + } + db.SetMaxOpenConns(20) + db.SetMaxIdleConns(10) + return db, func() { db.Close() }, nil +} + +func ProvideQuerier(ctx context.Context, db *sql.DB) (mysql.Querier, func(), error) { + q, err := mysql.Prepare(ctx, db) + if err != nil { + return nil, nil, fmt.Errorf("newQuerier: %w", err) + } + return q, func() { q.Close() }, nil +} + +func ProvideValidate() *validator.Validate { + return validator.New() +} + +func ProvideSnowflake(c config.Config) (*snowflake.Node, error) { + snowflake.Epoch = c.Epoch + n, err := snowflake.NewNode(c.Node) + if err != nil { + return nil, fmt.Errorf("newSnowflake: %w", err) + } + return n, nil +} + +var Set = wire.NewSet(ProvideSlog, ProvideDB, ProvideQuerier, ProvideValidate, ProvideSnowflake) diff --git a/server/route/middleware.go b/server/route/middleware.go new file mode 100644 index 0000000..e2a9d3f --- /dev/null +++ b/server/route/middleware.go @@ -0,0 +1,14 @@ +package route + +import ( + "net/http" + + "github.com/julienschmidt/httprouter" +) + +func warpCtJSON(handle httprouter.Handle) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + handle(w, r, p) + } +} diff --git a/server/route/route.go b/server/route/route.go new file mode 100644 index 0000000..182c62b --- /dev/null +++ b/server/route/route.go @@ -0,0 +1,37 @@ +package route + +import ( + "database/sql" + "fmt" + "log/slog" + + "github.com/bwmarrin/snowflake" + "github.com/go-playground/validator/v10" + "github.com/julienschmidt/httprouter" + "github.com/xmdhs/authlib-skin/config" + "github.com/xmdhs/authlib-skin/db/mysql" + "github.com/xmdhs/authlib-skin/handle" +) + +func NewRoute(l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) (*httprouter.Router, error) { + r := httprouter.New() + err := newYggdrasil(r) + if err != nil { + return nil, fmt.Errorf("NewRoute: %w", err) + } + err = newSkinApi(r, l, q, v, db, snow, c) + if err != nil { + return nil, fmt.Errorf("NewRoute: %w", err) + } + return r, nil +} + +func newYggdrasil(r *httprouter.Router) error { + r.POST("/api/authserver/authenticate", nil) + return nil +} + +func newSkinApi(r *httprouter.Router, l *slog.Logger, q mysql.Querier, v *validator.Validate, db *sql.DB, snow *snowflake.Node, c config.Config) error { + r.PUT("/api/v1/user/reg", handle.Reg(l, q, v, db, snow, c)) + return nil +} diff --git a/server/server.go b/server/server.go index 5176802..a4fe724 100644 --- a/server/server.go +++ b/server/server.go @@ -1,20 +1,40 @@ package server import ( + "log/slog" "net/http" + "sync/atomic" + "time" "github.com/julienschmidt/httprouter" + "github.com/xmdhs/authlib-skin/config" + "github.com/xmdhs/authlib-skin/utils" ) -func NewYggdrasil(r *httprouter.Router) error { - - r.POST("/api/authserver/authenticate", nil) - return nil -} - -func warpCtJSON(handle httprouter.Handle) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - handle(w, r, p) +func NewServer(c config.Config, sl *slog.Logger, route *httprouter.Router) (*http.Server, func()) { + trackid := atomic.Uint64{} + s := &http.Server{ + ReadTimeout: 10 * time.Second, + ReadHeaderTimeout: 5 * time.Second, + WriteTimeout: 20 * time.Second, + Addr: c.Port, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if sl.Enabled(ctx, slog.LevelInfo) { + ip, _ := utils.GetIP(r) + trackid.Add(1) + ctx = setCtx(ctx, &reqInfo{ + URL: r.URL.String(), + IP: ip, + TrackId: trackid.Load(), + }) + r = r.WithContext(ctx) + } + if sl.Enabled(ctx, slog.LevelDebug) { + sl.DebugContext(ctx, r.Method) + } + route.ServeHTTP(w, r) + }), } + return s, func() { s.Close() } } diff --git a/server/slog.go b/server/slog.go index e0a47c5..3240ace 100644 --- a/server/slog.go +++ b/server/slog.go @@ -32,7 +32,7 @@ type warpSlogHandle struct { } func (w *warpSlogHandle) Handle(ctx context.Context, r slog.Record) error { - if w.Enabled(ctx, slog.LevelDebug) { + if w.Enabled(ctx, slog.LevelInfo) { ri := getFromCtx(ctx) if ri != nil { r.AddAttrs(slog.String("ip", ri.IP), slog.String("url", ri.URL), slog.Uint64("trackID", ri.TrackId)) @@ -40,3 +40,10 @@ func (w *warpSlogHandle) Handle(ctx context.Context, r slog.Record) error { } return w.Handler.Handle(ctx, r) } + +func NewSlog(h slog.Handler) *slog.Logger { + l := slog.New(&warpSlogHandle{ + Handler: h, + }) + return l +} diff --git a/server/wire.go b/server/wire.go new file mode 100644 index 0000000..44ca990 --- /dev/null +++ b/server/wire.go @@ -0,0 +1,16 @@ +//go:build wireinject + +package server + +import ( + "context" + "net/http" + + "github.com/google/wire" + "github.com/xmdhs/authlib-skin/config" + "github.com/xmdhs/authlib-skin/server/route" +) + +func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func(), error) { + panic(wire.Build(Set, route.NewRoute, NewSlog, NewServer)) +} diff --git a/server/wire_gen.go b/server/wire_gen.go new file mode 100644 index 0000000..e9d911b --- /dev/null +++ b/server/wire_gen.go @@ -0,0 +1,53 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package server + +import ( + "context" + "github.com/xmdhs/authlib-skin/config" + "github.com/xmdhs/authlib-skin/server/route" + "net/http" +) + +import ( + _ "github.com/go-sql-driver/mysql" +) + +// Injectors from wire.go: + +func InitializeRoute(ctx context.Context, c config.Config) (*http.Server, func(), error) { + handler := ProvideSlog(c) + logger := NewSlog(handler) + db, cleanup, err := ProvideDB(c) + if err != nil { + return nil, nil, err + } + querier, cleanup2, err := ProvideQuerier(ctx, db) + if err != nil { + cleanup() + return nil, nil, err + } + validate := ProvideValidate() + node, err := ProvideSnowflake(c) + if err != nil { + cleanup2() + cleanup() + return nil, nil, err + } + router, err := route.NewRoute(logger, querier, validate, db, node, c) + if err != nil { + cleanup2() + cleanup() + return nil, nil, err + } + server, cleanup3 := NewServer(c, logger, router) + return server, func() { + cleanup3() + cleanup2() + cleanup() + }, nil +} diff --git a/service/user.go b/service/user.go index 397e08d..82659a6 100644 --- a/service/user.go +++ b/service/user.go @@ -18,7 +18,10 @@ import ( "github.com/xmdhs/authlib-skin/utils" ) -var ErrExistUser = errors.New("用户已存在") +var ( + ErrExistUser = errors.New("邮箱已存在") + ErrExitsName = errors.New("用户名已存在") +) func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *snowflake.Node, c config.Config, @@ -30,7 +33,9 @@ func Reg(ctx context.Context, u model.User, q mysql.Querier, db *sql.DB, snow *s if ou.Email != "" { return fmt.Errorf("Reg: %w", ErrExistUser) } - err = utils.WithTx(ctx, &sql.TxOptions{}, q, db, func(q mysql.Querier) error { + err = utils.WithTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelReadCommitted, + }, q, db, func(q mysql.Querier) error { p, s := utils.Argon2ID(u.Password) userID := snow.Generate().Int64() _, err := q.CreateUser(ctx, mysql.CreateUserParams{ diff --git a/utils/ip.go b/utils/ip.go new file mode 100644 index 0000000..6801a64 --- /dev/null +++ b/utils/ip.go @@ -0,0 +1,38 @@ +package utils + +import ( + "fmt" + "net" + "net/http" + "strings" +) + +func GetIP(r *http.Request) (string, error) { + //Get IP from the X-REAL-IP header + ip := r.Header.Get("X-REAL-IP") + netIP := net.ParseIP(ip) + if netIP != nil { + return ip, nil + } + + //Get IP from X-FORWARDED-FOR header + ips := r.Header.Get("X-FORWARDED-FOR") + splitIps := strings.Split(ips, ",") + for _, ip := range splitIps { + netIP := net.ParseIP(ip) + if netIP != nil { + return ip, nil + } + } + + //Get IP from RemoteAddr + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return "", err + } + netIP = net.ParseIP(ip) + if netIP != nil { + return ip, nil + } + return "", fmt.Errorf("no valid ip found") +} diff --git a/utils/tx.go b/utils/tx.go index e517287..a3f1b64 100644 --- a/utils/tx.go +++ b/utils/tx.go @@ -14,7 +14,6 @@ func WithTx(ctx context.Context, opts *sql.TxOptions, q mysql.Querier, db *sql.D }) var tx *sql.Tx if ok { - fmt.Println("事务开启") // remove me var err error tx, err = db.BeginTx(ctx, opts) if err != nil {