Commit 6e6f4ce8 authored by 李宇怀's avatar 李宇怀
Browse files

完成task E

parent 9df34dc7
...@@ -14,7 +14,8 @@ const ( ...@@ -14,7 +14,8 @@ const (
ErrShortLinkActive = 8 // 短链被中断使用 ErrShortLinkActive = 8 // 短链被中断使用
ErrShortLinkTime = 9 // 不在可用时间范围 ErrShortLinkTime = 9 // 不在可用时间范围
ErrOriginEmpty = 10 ErrCreateToken = 10 //生成Token失败
ErrGetToken = 11 //获取token失败
BadRequest = 400 BadRequest = 400
InternalError = 500 InternalError = 500
......
...@@ -14,8 +14,8 @@ func setupServerController(r *gin.RouterGroup) { ...@@ -14,8 +14,8 @@ func setupServerController(r *gin.RouterGroup) {
// Implemented in controller package. // Implemented in controller package.
} }
p := r.Group("/server") p := r.Group("/server")
p.GET("/getlink", controller.ParseTokenMidware(), lcw.Link) p.GET("/getlink", lcw.Link)
p.GET("/getveri", controller.ParseTokenMidware(), lcw.Veri) p.GET("/getveri", lcw.Veri)
} }
type ServerCtlWrapper struct { //Wrapper类隔离接口具体逻辑 type ServerCtlWrapper struct { //Wrapper类隔离接口具体逻辑
......
...@@ -16,7 +16,7 @@ func setupUserController(r *gin.RouterGroup) { ...@@ -16,7 +16,7 @@ func setupUserController(r *gin.RouterGroup) {
p := r.Group("/user") p := r.Group("/user")
p.POST("/register", lcw.Register) p.POST("/register", lcw.Register)
p.GET("/getveri", lcw.GetVeri) p.GET("/getveri", lcw.GetVeri)
p.POST("/login", controller.ParseTokenMidware(), lcw.Login) p.POST("/login", lcw.Login)
p.GET("/getinfo", controller.ParseTokenMidware(), lcw.GetInfo) p.GET("/getinfo", controller.ParseTokenMidware(), lcw.GetInfo)
p.POST("/modifyinfo", controller.ParseTokenMidware(), lcw.ModifyInfo) p.POST("/modifyinfo", controller.ParseTokenMidware(), lcw.ModifyInfo)
p.POST("/modifypwd", controller.ParseTokenMidware(), lcw.ModifyPwd) p.POST("/modifypwd", controller.ParseTokenMidware(), lcw.ModifyPwd)
......
...@@ -146,8 +146,6 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe ...@@ -146,8 +146,6 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
......
...@@ -118,11 +118,11 @@ func (c *LinkController) GetInfo(ctx *gin.Context, req *dto.GetLinkInfoReq) (*dt ...@@ -118,11 +118,11 @@ func (c *LinkController) GetInfo(ctx *gin.Context, req *dto.GetLinkInfoReq) (*dt
} }
var link model.Link var link model.Link
err := dao.DB(ctx).Where(&model.Link{Short: getinfoLink.Short}).First(&link).Error err := dao.DB(ctx).Table("Links").Where(&model.Link{Short: getinfoLink.Short}).First(&link).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
logrus.Fatal(err) logrus.Fatal(err)
return nil, stacktrace.PropagateWithCode(nil, dto.ErrNoShortLink, "ErrNoShortLink") return nil, stacktrace.PropagateWithCode(err, dto.ErrNoShortLink, "ErrNoShortLink")
} }
return nil, err return nil, err
} }
...@@ -148,10 +148,10 @@ func (c *LinkController) Update(ctx *gin.Context, req *dto.LinkUpdateReq) error ...@@ -148,10 +148,10 @@ func (c *LinkController) Update(ctx *gin.Context, req *dto.LinkUpdateReq) error
Active: req.Active, Active: req.Active,
OwnerID: userID, OwnerID: userID,
} }
err := dao.DB(ctx).Updates(&updateLink).Error err := dao.DB(ctx).Table("Links").Where(&model.Link{Short: updateLink.Short}).Updates(&updateLink).Error
if err != nil { if err != nil {
logrus.Fatal(err) logrus.Fatal(err)
return err return stacktrace.PropagateWithCode(err, dto.ErrShortLinkExist, "ErrShortLinkExist")
} }
return nil return nil
} }
......
...@@ -39,7 +39,7 @@ func teardown() { ...@@ -39,7 +39,7 @@ func teardown() {
} }
// TestLinkController_Create函数用于测试Create函数 // TestLinkController_Create函数用于测试Create函数
func TestLinkController_Create(t *testing.T) { func TestCreate(t *testing.T) {
var tests = []struct { var tests = []struct {
userID uint userID uint
...@@ -100,7 +100,7 @@ func TestLinkController_Create(t *testing.T) { ...@@ -100,7 +100,7 @@ func TestLinkController_Create(t *testing.T) {
} }
func TestLinkController_Delete(t *testing.T) { func TestDelete(t *testing.T) {
var tests = []struct { var tests = []struct {
userID uint userID uint
...@@ -177,3 +177,177 @@ func TestGetinfo(t *testing.T) { ...@@ -177,3 +177,177 @@ func TestGetinfo(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
} }
func TestUpdate(t *testing.T) {
var tests = []struct {
userID uint
input dto.LinkUpdateReq
errorCode stacktrace.ErrorCode
}{
// 测试用例1:正常更新短链接,指定short字段搜索短连接并更新为req中的值
{
userID: 2,
input: dto.LinkUpdateReq{
Short: "bing",
Comment: "test updata 2",
Origin: "https://www.bing.com",
StartTime: null.NewTime(time.Date(2023, 7, 30, 12, 0, 12, 0, time.Local), true),
EndTime: null.NewTime(time.Date(2023, 8, 1, 12, 0, 12, 0, time.Local), true),
Active: null.BoolFrom(true),
},
errorCode: dto.NoErr,
},
}
testLinkCtl := new(LinkController)
for _, test := range tests {
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Set(model.USER_ID_KEY, test.userID)
err := testLinkCtl.Update(ctx, &test.input)
if test.errorCode != 0 {
assert.Equal(t, test.errorCode, stacktrace.GetCode(err))
continue
}
assert.NoError(t, err)
}
}
func TestRegister(t *testing.T) {
var tests = []struct {
userID uint
input dto.UserRegisterReq
errorCode stacktrace.ErrorCode
}{
// 测试用例1:正常创建短链接,指定short字段,使用自定义的short值
{
userID: 3,
input: dto.UserRegisterReq{
Email: "3220000000@zju.edu.cn",
Name: "3333",
Password: "123",
},
errorCode: dto.NoErr,
},
}
testUserCtl := new(UserController)
for _, test := range tests {
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Set(model.USER_ID_KEY, test.userID)
err := testUserCtl.Register(ctx, &test.input)
if test.errorCode != 0 {
assert.Equal(t, test.errorCode, stacktrace.GetCode(err))
continue
}
assert.NoError(t, err)
}
}
func TestLogin(t *testing.T) {
var tests = []struct {
userID uint
input dto.UserLoginReq
errorCode stacktrace.ErrorCode
}{
// 测试用例1:密码错误
{
userID: 2,
input: dto.UserLoginReq{
CAPTCHAID: "1",
CAPTCHAValue: "1",
Email: "3220106025@zju.edu.cn",
Password: "1234",
},
errorCode: dto.NoErr,
},
}
testUserCtl := new(UserController)
for _, test := range tests {
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Set(model.USER_ID_KEY, test.userID)
err := testUserCtl.Login(ctx, &test.input)
if test.errorCode != 0 {
assert.Equal(t, test.errorCode, stacktrace.GetCode(err))
continue
}
assert.NoError(t, err)
}
}
// 测试修改信息函数
func TestUserModifyInfoReq(t *testing.T) {
var tests = []struct {
userID uint
input dto.UserModifyInfoReq
errorCode stacktrace.ErrorCode
}{
// 测试用例1:修改邮箱
{
userID: 1,
input: dto.UserModifyInfoReq{
Email: "3220406011@thu.edu.cn",
ID: 1,
Name: "thu",
},
errorCode: dto.NoErr,
},
}
testUserCtl := new(UserController)
for _, test := range tests {
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Set(model.USER_ID_KEY, test.userID)
err := testUserCtl.ModifyInfo(ctx, &test.input)
if test.errorCode != 0 {
assert.Equal(t, test.errorCode, stacktrace.GetCode(err))
continue
}
assert.NoError(t, err)
}
}
func TestUserModifyPwdReq(t *testing.T) {
var tests = []struct {
userID uint
input dto.UserModifyPwdReq
errorCode stacktrace.ErrorCode
}{
// 测试用例1:旧密码不正确
{
userID: 3,
input: dto.UserModifyPwdReq{
NewPwd: "111111",
OldPwd: "123",
},
errorCode: dto.NoErr,
},
}
testUserCtl := new(UserController)
for _, test := range tests {
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Set(model.USER_ID_KEY, test.userID)
err := testUserCtl.ModifyPwd(ctx, &test.input)
if test.errorCode != 0 {
assert.Equal(t, test.errorCode, stacktrace.GetCode(err))
continue
}
assert.NoError(t, err)
}
}
package controller package controller
import ( import (
"go-svc-tpl/api/dto"
"go-svc-tpl/internal/dao" "go-svc-tpl/internal/dao"
"go-svc-tpl/internal/dao/model" "go-svc-tpl/internal/dao/model"
"net/http" "net/http"
...@@ -16,7 +17,10 @@ func ParseTokenMidware() gin.HandlerFunc { ...@@ -16,7 +17,10 @@ func ParseTokenMidware() gin.HandlerFunc {
return func(ctx *gin.Context) { return func(ctx *gin.Context) {
// 获取authorization header // 获取authorization header
tokenString := ctx.GetHeader("Authorization") tokenString, err := ctx.Cookie("token")
if err != nil {
ctx.JSON(dto.ErrGetToken, gin.H{"code": dto.ErrGetToken, "message": "ErrGetToken."})
}
// validate token formate // validate token formate
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") { if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
...@@ -39,7 +43,7 @@ func ParseTokenMidware() gin.HandlerFunc { ...@@ -39,7 +43,7 @@ func ParseTokenMidware() gin.HandlerFunc {
userId := claims.UserId userId := claims.UserId
DB := dao.DB DB := dao.DB
var user model.User var user model.User
DB(ctx).First(&user, userId) DB(ctx).Table("User").First(&user, userId)
// 用户不存在 // 用户不存在
if user.ID == 0 { if user.ID == 0 {
...@@ -61,6 +65,10 @@ var jwtKey = []byte("a_secret_crect") ...@@ -61,6 +65,10 @@ var jwtKey = []byte("a_secret_crect")
// token的claim // token的claim
type Claims struct { type Claims struct {
UserId uint UserId uint
Email string `json:"email"` // 用户邮箱​
Name string `json:"name"` // 用户名​
Password string `json:"password"` // 用户密码​
jwt.StandardClaims jwt.StandardClaims
} }
...@@ -73,7 +81,11 @@ func ReleaseToken(user model.User) (string, error) { ...@@ -73,7 +81,11 @@ func ReleaseToken(user model.User) (string, error) {
claims := &Claims{ claims := &Claims{
//自定义字段 //自定义字段
UserId: user.ID, UserId: user.ID,
Email: user.Email,
Name: user.Name,
Password: user.Password,
//标准字段 //标准字段
StandardClaims: jwt.StandardClaims{ StandardClaims: jwt.StandardClaims{
......
...@@ -7,16 +7,14 @@ import ( ...@@ -7,16 +7,14 @@ import (
"net/http" "net/http"
"regexp" "regexp"
"go-svc-tpl/utils/stacktrace"
"bytes" "bytes"
"go-svc-tpl/utils/stacktrace"
"path" "path"
"strings" "strings"
"time" "time"
"github.com/dchest/captcha" "github.com/dchest/captcha"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -159,15 +157,19 @@ func (c *UserController) Login(ctx *gin.Context, req *dto.UserLoginReq) error { ...@@ -159,15 +157,19 @@ func (c *UserController) Login(ctx *gin.Context, req *dto.UserLoginReq) error {
} }
//设置cookie //设置cookie
//ctx.SetCookie("id", strconv.FormatUint(uint64(userID), 10), 3600, "/", "localhost", false, true) //ctx.SetCookie("id", strconv.FormatUint(uint64(userID), 10), 3600, "/", "localhost", false, true)
ID = user.ID //使用jwt密钥生成token
sessionID = uuid.New().String() tokenString, err := ReleaseToken(*newUser)
ctx.SetCookie("session_id", sessionID, 86400, "/", "localhost", false, true) if err != nil {
return stacktrace.PropagateWithCode(err, dto.ErrCreateToken, "ErrCreateToken.")
}
ctx.SetCookie("token", tokenString, 86400, "/", "localhost", false, true)
return nil return nil
} }
// Logout // Logout
func (c *UserController) Logout(ctx *gin.Context) error { func (c *UserController) Logout(ctx *gin.Context) error {
ctx.SetCookie("id", "", -1, "/", "localhost", false, true) //删除对应的cookie
ctx.SetCookie("token", "", -1, "/", "localhost", false, true)
return nil return nil
} }
...@@ -216,11 +218,18 @@ func (c *UserController) ModifyPwd(ctx *gin.Context, req *dto.UserModifyPwdReq) ...@@ -216,11 +218,18 @@ func (c *UserController) ModifyPwd(ctx *gin.Context, req *dto.UserModifyPwdReq)
if err != nil { if err != nil {
return stacktrace.PropagateWithCode(err, dto.ErrUserNotFound, "ErrUserNotFound.") return stacktrace.PropagateWithCode(err, dto.ErrUserNotFound, "ErrUserNotFound.")
} }
if user.Password != req.OldPwd {
return stacktrace.PropagateWithCode(err, dto.ErrPassword, "ErrPassword.") err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.OldPwd))
if err != nil {
return stacktrace.PropagateWithCode(nil, dto.ErrPassword, "ErrPassword.")
}
//密码Hash处理
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(req.NewPwd), bcrypt.DefaultCost)
if err != nil {
return stacktrace.PropagateWithCode(err, dto.InternalError, "InternalError.")
} }
newUser := &model.User{ newUser := &model.User{
Password: req.NewPwd, Password: string(hashedBytes),
} }
err = dao.DB(ctx).Model(&model.User{ID: userID}).Updates(&newUser).Error err = dao.DB(ctx).Model(&model.User{ID: userID}).Updates(&newUser).Error
if err != nil { if err != nil {
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment