diff --git a/frontend/src/apis/apis.ts b/frontend/src/apis/apis.ts index 2cca4fb..b191fb8 100644 --- a/frontend/src/apis/apis.ts +++ b/frontend/src/apis/apis.ts @@ -24,7 +24,7 @@ export async function register(email: string, username: string, password: string "CaptchaToken": captchaToken }) }) - return await apiGet(v) + return await apiGet(v) } export async function userInfo(token: string) { diff --git a/frontend/src/views/Register.tsx b/frontend/src/views/Register.tsx index a55c3f5..dfd1c7e 100644 --- a/frontend/src/views/Register.tsx +++ b/frontend/src/views/Register.tsx @@ -20,6 +20,8 @@ import CaptchaWidget from '@/components/CaptchaWidget'; import type { refType as CaptchaWidgetRef } from '@/components/CaptchaWidget' import useTitle from '@/hooks/useTitle'; import { ApiErr } from '@/apis/error'; +import { useSetAtom } from 'jotai'; +import { token, user } from '@/store/store'; export default function SignUp() { const [regErr, setRegErr] = useState(""); @@ -28,6 +30,8 @@ export default function SignUp() { const captchaRef = useRef(null) const [loading, setLoading] = useState(false); useTitle("注册") + const setToken = useSetAtom(token) + const setUserInfo = useSetAtom(user) const checkList = React.useRef>(new Map()) @@ -51,7 +55,15 @@ export default function SignUp() { } setLoading(true) register(d.email ?? "", d.username ?? "", d.password ?? "", captchaToken). - then(() => navigate("/login")). + then(v => { + if (!v) return + setToken(v.token) + setUserInfo({ + uuid: v.uuid, + name: v.name, + }) + navigate("/profile") + }). catch(v => { captchaRef.current?.reload() console.warn(v) diff --git a/handle/user.go b/handle/user.go index ecc0faa..245e015 100644 --- a/handle/user.go +++ b/handle/user.go @@ -33,13 +33,14 @@ func (h *Handel) Reg() http.HandlerFunc { h.handleError(ctx, w, err.Error(), model.ErrUnknown, 500, slog.LevelWarn) return } - err = h.webService.Reg(ctx, u, rip, ip) + lr, err := h.webService.Reg(ctx, u, rip, ip) if err != nil { h.handleErrorService(ctx, w, err) return } - encodeJson(w, model.API[any]{ + encodeJson(w, model.API[model.LoginRep]{ Code: 0, + Data: lr, }) } } diff --git a/service/admin_test.go b/service/admin_test.go index cdca22d..02e3e5b 100644 --- a/service/admin_test.go +++ b/service/admin_test.go @@ -11,25 +11,19 @@ import ( func TestWebService_Auth(t *testing.T) { ctx := context.Background() - err := webService.Reg(ctx, model.UserReg{ + lr, err := webService.Reg(ctx, model.UserReg{ Email: "TestWebService_Auth@xmdhs.com", Password: "TestWebService_Auth", Name: "TestWebService_Auth", CaptchaToken: "", }, "127.0.1.0/24", "127.0.1.0") require.Nil(t, err) + require.Equal(t, lr.Name, "TestWebService_Auth") - l, err := webService.Login(ctx, model.Login{ - Email: "TestWebService_Auth@xmdhs.com", - Password: "TestWebService_Auth", - CaptchaToken: "", - }, "0.0.0.0") + token, err := webService.Auth(ctx, lr.Token) require.Nil(t, err) - token, err := webService.Auth(ctx, l.Token) - require.Nil(t, err) - - assert.Equal(t, token.Subject, l.UUID) + assert.Equal(t, token.Subject, lr.UUID) assert.Equal(t, token.Tid, "1") type args struct { diff --git a/service/user.go b/service/user.go index f07c883..f61445e 100644 --- a/service/user.go +++ b/service/user.go @@ -25,7 +25,7 @@ var ( ErrChangeName = errors.New("离线模式 uuid 不允许修改用户名") ) -func (w *WebService) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip string) error { +func (w *WebService) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip string) (model.LoginRep, error) { var userUuid string if w.config.OfflineUUID { userUuid = utils.UUIDGen(u.Name) @@ -35,21 +35,23 @@ func (w *WebService) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip stri err := w.verifyCaptcha(ctx, u.CaptchaToken, ip) if err != nil { - return fmt.Errorf("Reg: %w", err) + return model.LoginRep{}, fmt.Errorf("Reg: %w", err) } if w.config.MaxIpUser != 0 { c, err := w.client.User.Query().Where(user.RegIPEQ(ipPrefix)).Count(ctx) if err != nil { - return fmt.Errorf("Reg: %w", err) + return model.LoginRep{}, fmt.Errorf("Reg: %w", err) } if c >= w.config.MaxIpUser { - return fmt.Errorf("Reg: %w", ErrRegLimit) + return model.LoginRep{}, fmt.Errorf("Reg: %w", ErrRegLimit) } } p, s := utils.Argon2ID(u.Password) + var du *ent.User + err = utils.WithTx(ctx, w.client, func(tx *ent.Tx) error { count, err := tx.User.Query().Where(user.EmailEQ(u.Email)).ForUpdateA().Count(ctx) if err != nil { @@ -65,7 +67,7 @@ func (w *WebService) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip stri if nameCount != 0 { return ErrExitsName } - du, err := tx.User.Create(). + du, err = tx.User.Create(). SetEmail(u.Email). SetPassword(p). SetSalt(s). @@ -92,9 +94,18 @@ func (w *WebService) Reg(ctx context.Context, u model.UserReg, ipPrefix, ip stri return nil }) if err != nil { - return fmt.Errorf("Reg: %w", err) + return model.LoginRep{}, fmt.Errorf("Reg: %w", err) } - return nil + jwt, err := utilsService.CreateToken(ctx, du, w.client, w.cache, w.prikey, "web") + if err != nil { + return model.LoginRep{}, fmt.Errorf("Login: %w", err) + } + + return model.LoginRep{ + Token: jwt, + Name: u.Name, + UUID: userUuid, + }, nil } func (w *WebService) Login(ctx context.Context, l model.Login, ip string) (model.LoginRep, error) { diff --git a/service/user_test.go b/service/user_test.go index c9363d0..17a6097 100644 --- a/service/user_test.go +++ b/service/user_test.go @@ -143,7 +143,7 @@ func TestWebService_Reg(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.w.Reg(tt.args.ctx, tt.args.u, tt.args.ipPrefix, tt.args.ip); (err != nil) != tt.wantErr { + if _, err := tt.w.Reg(tt.args.ctx, tt.args.u, tt.args.ipPrefix, tt.args.ip); (err != nil) != tt.wantErr { t.Errorf("WebService.Reg() error = %v, wantErr %v", err, tt.wantErr) } })