Commit e674b465 authored by 李宇怀's avatar 李宇怀
Browse files

写了个很朴素的测试

parent 7b23656d
...@@ -14,6 +14,8 @@ const ( ...@@ -14,6 +14,8 @@ const (
ErrShortLinkActive = 8 // 短链被中断使用 ErrShortLinkActive = 8 // 短链被中断使用
ErrShortLinkTime = 9 // 不在可用时间范围 ErrShortLinkTime = 9 // 不在可用时间范围
ErrOriginEmpty = 10
BadRequest = 400 BadRequest = 400
InternalError = 500 InternalError = 500
) )
...@@ -14,9 +14,9 @@ func setupLinkController(r *gin.RouterGroup) { ...@@ -14,9 +14,9 @@ func setupLinkController(r *gin.RouterGroup) {
// Implemented in controller package. // Implemented in controller package.
} }
p := r.Group("/link") p := r.Group("/link")
p.POST("/create", controller.ParseTokenMidware(), lcw.Create) p.POST("/create", lcw.Create)
p.POST("/delete", controller.ParseTokenMidware(), lcw.Delete) p.POST("/delete", controller.ParseTokenMidware(), lcw.Delete)
p.GET("/getinfo", controller.ParseTokenMidware(), lcw.GetInfo) p.GET("/getinfo", lcw.GetInfo)
p.POST("/update", controller.ParseTokenMidware(), lcw.Update) p.POST("/update", controller.ParseTokenMidware(), lcw.Update)
p.GET("/getlist", controller.ParseTokenMidware(), lcw.GetList) p.GET("/getlist", controller.ParseTokenMidware(), lcw.GetList)
} }
......
...@@ -14,12 +14,13 @@ func setupUserController(r *gin.RouterGroup) { ...@@ -14,12 +14,13 @@ func setupUserController(r *gin.RouterGroup) {
// Implemented in controller package. // Implemented in controller package.
} }
p := r.Group("/user") p := r.Group("/user")
p.POST("/register", controller.ParseTokenMidware(), lcw.Register) p.POST("/register", lcw.Register)
p.GET("/getveri", controller.ParseTokenMidware(), lcw.GetVeri) p.GET("/getveri", lcw.GetVeri)
p.POST("/login", controller.ParseTokenMidware(), lcw.Login) p.POST("/login", controller.ParseTokenMidware(), lcw.Login)
p.GET("/getinfo", controller.ParseTokenMidware(), lcw.GetInfo) p.GET("/getinfo", controller.ParseTokenMidware(), lcw.GetInfo)
p.POST("/modifyinfo", controller.ParseTokenMidware(), lcw.ModifyInfo) p.POST("/modifyinfo", controller.ParseTokenMidware(), lcw.ModifyInfo)
p.POST("/modifypwd", controller.ParseTokenMidware(), lcw.ModifyPwd) p.POST("/modifypwd", controller.ParseTokenMidware(), lcw.ModifyPwd)
p.POST("/logout", controller.ParseTokenMidware(), lcw.ModifyPwd)
} }
type UserCtlWrapper struct { //Wrapper类隔离接口具体逻辑 type UserCtlWrapper struct { //Wrapper类隔离接口具体逻辑
...@@ -115,3 +116,14 @@ func (w *UserCtlWrapper) ModifyPwd(c *gin.Context) { ...@@ -115,3 +116,14 @@ func (w *UserCtlWrapper) ModifyPwd(c *gin.Context) {
} }
dto.ResponseSuccess(c, "modify password successfully") dto.ResponseSuccess(c, "modify password successfully")
} }
// logout
func (w *UserCtlWrapper) Logout(c *gin.Context) {
err := w.ctl.Logout(c)
if err != nil {
dto.ResponseFail(c, err)
return
}
dto.ResponseSuccess(c, "logout successfully")
}
...@@ -7,6 +7,7 @@ require ( ...@@ -7,6 +7,7 @@ require (
github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/go-sql-driver/mysql v1.7.0
github.com/google/uuid v1.3.1 github.com/google/uuid v1.3.1
github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177 github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
...@@ -29,7 +30,6 @@ require ( ...@@ -29,7 +30,6 @@ require (
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect
......
...@@ -6,6 +6,9 @@ import ( ...@@ -6,6 +6,9 @@ import (
"go-svc-tpl/internal/dao/model" "go-svc-tpl/internal/dao/model"
"go-svc-tpl/utils/stacktrace" "go-svc-tpl/utils/stacktrace"
"crypto/md5"
"encoding/hex"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gopkg.in/guregu/null.v4" "gopkg.in/guregu/null.v4"
...@@ -50,51 +53,56 @@ func (c *LinkController) Create(ctx *gin.Context, req *dto.LinkCreateReq) (*dto. ...@@ -50,51 +53,56 @@ func (c *LinkController) Create(ctx *gin.Context, req *dto.LinkCreateReq) (*dto.
Active: null.BoolFrom(true), Active: null.BoolFrom(true),
OwnerID: userID, OwnerID: userID,
} }
err := dao.DB(ctx).Create(newLink).Error if newLink.Short == "" {
if err != nil { newLink.Short = GenerateShort(newLink.Origin)
logrus.Fatal(err)
return nil, err
} }
var link model.Link var link model.Link
err = dao.DB(ctx).Where(&model.Link{Short: newLink.Short}).First(&link).Error err := dao.DB(ctx).Where(&model.Link{Short: newLink.Short}).First(&link).Error
if err != nil { if err != nil {
// 没找到相关记录 说明不是短连接重复 // 没找到相关记录 说明不是短连接重复
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
logrus.Fatal(err) //logrus.Fatal(err)
return nil, err //存到数据库中
err = dao.DB(ctx).Create(newLink).Error
if err != nil {
logrus.Error("Internal Error.")
return nil, stacktrace.PropagateWithCode(err, dto.InternalError, "InternalError")
}
// 没有错误 正常创建
return &dto.LinkCreateResp{
Short: newLink.Short,
Origin: newLink.Origin,
Comment: newLink.Comment,
StartTime: newLink.StartTime,
EndTime: newLink.EndTime,
Active: newLink.Active,
}, nil
} }
return nil, stacktrace.PropagateWithCode(nil, dto.ErrShortLinkExist, "ErrShortLinkExist") logrus.Error("Internal Error.")
return nil, stacktrace.PropagateWithCode(err, dto.InternalError, "InternalError")
} }
return &dto.LinkCreateResp{ //如果找到了的话
Short: newLink.Short, return nil, stacktrace.PropagateWithCode(err, dto.ErrShortLinkExist, "ErrShortLinkExist")
Origin: newLink.Origin,
Comment: newLink.Comment,
StartTime: newLink.StartTime,
EndTime: newLink.EndTime,
Active: newLink.Active,
}, nil
} }
// delete // delete
func (c *LinkController) Delete(ctx *gin.Context, req *dto.LinkDeleteReq) error { func (c *LinkController) Delete(ctx *gin.Context, req *dto.LinkDeleteReq) error {
userID := ctx.GetUint(model.USER_ID_KEY) //userID := ctx.GetUint(model.USER_ID_KEY)
deleteLink := &model.Link{ deleteLink := &model.Link{
Short: req.Short, Short: req.Short,
OwnerID: userID,
} }
err := dao.DB(ctx).Delete(&deleteLink).Error err := dao.DB(ctx).Where(&model.Link{Short: deleteLink.Short}).Delete(deleteLink).Error
if err != nil { if err != nil {
logrus.Fatal(err) logrus.Fatal(err)
return stacktrace.PropagateWithCode(nil, dto.ErrNoShortLink, "ErrNoShortLink") return err
} }
var link model.Link var link model.Link
err = dao.DB(ctx).First(&link, "Short = ?", deleteLink.Short).Error err = dao.DB(ctx).Where(&model.Link{Short: deleteLink.Short}).First(&link).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
logrus.Fatal(err) logrus.Fatal(err)
return stacktrace.PropagateWithCode(nil, dto.ErrNoShortLink, "ErrNoShortLink") return stacktrace.PropagateWithCode(err, dto.ErrNoShortLink, "ErrNoShortLink")
} }
return err return err
} }
...@@ -110,14 +118,14 @@ func (c *LinkController) GetInfo(ctx *gin.Context, req *dto.GetLinkInfoReq) (*dt ...@@ -110,14 +118,14 @@ func (c *LinkController) GetInfo(ctx *gin.Context, req *dto.GetLinkInfoReq) (*dt
} }
var link model.Link var link model.Link
err := dao.DB(ctx).First(&link, "Short = ?", getinfoLink.Short).Error err := dao.DB(ctx).Where(&model.Link{Short: getinfoLink.Short}).First(&link).Error
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
logrus.Fatal(err) logrus.Fatal(err)
return nil, stacktrace.PropagateWithCode(nil, dto.ErrNoShortLink, "ErrNoShortLink")
} }
return nil, err return nil, err
} }
return &dto.GetLinkInfoResp{ return &dto.GetLinkInfoResp{
Short: link.Short, Short: link.Short,
Origin: link.Origin, Origin: link.Origin,
...@@ -142,15 +150,7 @@ func (c *LinkController) Update(ctx *gin.Context, req *dto.LinkUpdateReq) error ...@@ -142,15 +150,7 @@ func (c *LinkController) Update(ctx *gin.Context, req *dto.LinkUpdateReq) error
} }
err := dao.DB(ctx).Updates(&updateLink).Error err := dao.DB(ctx).Updates(&updateLink).Error
if err != nil { if err != nil {
return err logrus.Fatal(err)
}
var link model.Link
err = dao.DB(ctx).First(&link, "Short = ?", updateLink.Short).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
logrus.Fatal(err)
}
return err return err
} }
return nil return nil
...@@ -179,3 +179,15 @@ func (c *LinkController) GetList(ctx *gin.Context, req *dto.GetLinkListReq) (*dt ...@@ -179,3 +179,15 @@ func (c *LinkController) GetList(ctx *gin.Context, req *dto.GetLinkListReq) (*dt
return resp, nil return resp, nil
} }
// 生成短链接
// 输入长链接(字符串) 返回短链接(字符串)
func GenerateShort(origin string) string {
// 使用MD5哈希函数对长链接进行哈希计算
hash := md5.Sum([]byte(origin))
// 将哈希结果转换为16进制字符串
hashString := hex.EncodeToString(hash[:])
// 取哈希结果的前8个字符作为短链接
short := hashString[:8]
return short
}
package controller
import (
"go-svc-tpl/api/dto"
"context"
"go-svc-tpl/internal/dao"
"go-svc-tpl/internal/dao/model"
"go-svc-tpl/utils/stacktrace"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"gopkg.in/guregu/null.v4"
)
func setup() {
db := testInitDB()
var testDB = func(ctx context.Context) *dao.DBMS {
return &dao.DBMS{db.WithContext(ctx)}
}
dao.DB = testDB
gin.SetMode(gin.TestMode)
}
func TestMain(m *testing.M) {
setup()
m.Run()
teardown()
}
func teardown() {
dao.InitDB()
gin.SetMode(gin.ReleaseMode)
}
// TestLinkController_Create函数用于测试Create函数
func TestLinkController_Create(t *testing.T) {
var tests = []struct {
userID uint
input dto.LinkCreateReq
errorCode stacktrace.ErrorCode
}{
// // 测试用例1:正常创建短链接,不指定short字段,使用默认生成的short值
// {
// userID: 1,
// input: dto.LinkCreateReq{
// Short: "",
// Comment: "test link 1",
// Origin: "https://www.google.com",
// StartTime: null.TimeFrom(time.Now()),
// EndTime: null.TimeFrom(time.Now().Add(24 * time.Hour)),
// },
// errorCode: 0,
// },
// 测试用例2:正常创建短链接,指定short字段,使用自定义的short值
{
userID: 2,
input: dto.LinkCreateReq{
Short: "bing",
Comment: "test link 2",
Origin: "https://www.bing.com",
StartTime: null.TimeFrom(time.Now()),
EndTime: null.TimeFrom(time.Now().Add(24 * time.Hour)),
},
errorCode: 0,
},
// // 测试用例3:创建短链接失败,因为short字段已经存在
// {
// userID: 3,
// input: dto.LinkCreateReq{
// Short: "bing",
// Comment: "test link 3",
// Origin: "https://www.yahoo.com",
// StartTime: null.TimeFrom(time.Now()),
// EndTime: null.TimeFrom(time.Now().Add(24 * time.Hour)),
// },
// errorCode: dto.ErrShortLinkExist,
// },
}
var testLinkCtl ILinkController
testLinkCtl = new(LinkController)
for _, test := range tests {
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Set(model.USER_ID_KEY, test.userID)
_, err := testLinkCtl.Create(ctx, &test.input)
if test.errorCode != 0 {
assert.Equal(t, test.errorCode, stacktrace.GetCode(err))
continue
}
assert.NoError(t, err)
}
}
package controller package controller
import ( import (
"go-svc-tpl/internal/dao"
"go-svc-tpl/internal/dao/model"
"net/http" "net/http"
"strings"
"time"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -9,27 +13,99 @@ import ( ...@@ -9,27 +13,99 @@ import (
// 解析和验证身份认证令牌 // 解析和验证身份认证令牌
func ParseTokenMidware() gin.HandlerFunc { func ParseTokenMidware() gin.HandlerFunc {
return func(c *gin.Context) { return func(ctx *gin.Context) {
// 在此处添加解析和验证令牌的逻辑
tokenString := c.GetHeader("Authorization") // 获取authorization header
if tokenString == "" { tokenString := ctx.GetHeader("Authorization")
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing token"})
// validate token formate
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
ctx.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "权限不足"})
ctx.Abort()
return return
} }
// 解析和验证令牌 //提取token的有效部分("Bearer "共占7位)
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { tokenString = tokenString[7:]
// 在此处添加密钥获取逻辑
// 通常从配置文件、环境变量等获取密钥
return []byte("your-secret-key"), nil
})
token, claims, err := ParseToken(tokenString)
if err != nil || !token.Valid { if err != nil || !token.Valid {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) ctx.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "权限不足"})
ctx.Abort()
return
}
// 验证通过后获取claim 中的userId
userId := claims.UserId
DB := dao.DB
var user model.User
DB(ctx).First(&user, userId)
// 用户不存在
if user.ID == 0 {
ctx.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "权限不足"})
ctx.Abort()
return return
} }
// 令牌验证通过,继续处理请求 // 用户存在将user的信息写入上下文,方便读取
c.Next() ctx.Set("user", user)
ctx.Next()
}
}
// jwt加密密钥
var jwtKey = []byte("a_secret_crect")
// token的claim
type Claims struct {
UserId uint
jwt.StandardClaims
}
// 发放token
func ReleaseToken(user model.User) (string, error) {
//token的有效期
expirationTime := time.Now().Add(7 * 24 * time.Hour)
claims := &Claims{
//自定义字段
UserId: user.ID,
//标准字段
StandardClaims: jwt.StandardClaims{
//过期时间
ExpiresAt: expirationTime.Unix(),
//发放的时间
IssuedAt: time.Now().Unix(),
//发放者
Issuer: "127.0.0.1",
//主题
Subject: "user token",
},
}
//使用jwt密钥生成token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(jwtKey)
if err != nil {
return "", err
} }
//返回token
return tokenString, nil
}
// 从tokenString中解析出claims并返回
func ParseToken(tokenString string) (*jwt.Token, *Claims, error) {
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (i interface{}, err error) {
return jwtKey, nil
})
return token, claims, err
} }
...@@ -36,20 +36,24 @@ func (c *ServerController) Link(ctx *gin.Context, req *dto.ServerLinkReq) error ...@@ -36,20 +36,24 @@ func (c *ServerController) Link(ctx *gin.Context, req *dto.ServerLinkReq) error
if err := dao.DB(ctx).Where(&model.Link{Short: req.Short}).First(&link).Error; err != nil { if err := dao.DB(ctx).Where(&model.Link{Short: req.Short}).First(&link).Error; err != nil {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
// 如果找不到对应的长链接,返回404错误 // 如果找不到对应的长链接,返回 短连接不存在(6) 错误
ctx.JSON(http.StatusNotFound, gin.H{"error": "Link not found"}) ctx.JSON(dto.ErrNoShortLink, gin.H{
"error": "Link not found",
})
} else { } else {
// 如果发生其他错误,返回500错误 // 如果发生其他错误,返回500错误
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"}) ctx.JSON(dto.InternalError, gin.H{
"error": "Internal server error",
})
} }
return err return err
} }
// 如果找到了对应的长链接,进行重定向 // 如果找到了对应的长链接,进行重定向
ctx.Redirect(http.StatusMovedPermanently, link.Origin) ctx.Redirect(http.StatusMovedPermanently, link.Origin)
return nil return nil
} }
// 返回验证码的图片
func (c *ServerController) Veri(ctx *gin.Context, req *dto.ServerVeriReq) error { func (c *ServerController) Veri(ctx *gin.Context, req *dto.ServerVeriReq) error {
// 获取验证码ID // 获取验证码ID
id := req.Target id := req.Target
......
package controller
import (
"go-svc-tpl/internal/dao/model"
_ "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
var DB *gorm.DB
func testInitDB() *gorm.DB {
// 参考 https://github.com/go-sql-driver/mysql#dsn-data-source-name 获取详情
dsn := "root:1351508.@tcp(127.0.0.1:3306)/demo?charset=utf8mb4&parseTime=True&loc=Local"
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
panic("failed to connect database, err:" + err.Error())
}
//迁移
db.AutoMigrate(&model.Link{})
DB = db
return db
}
func GetDB() *gorm.DB {
return DB
}
...@@ -131,7 +131,9 @@ PropagateWithCode is similar to Propagate but also attaches an error code. ...@@ -131,7 +131,9 @@ PropagateWithCode is similar to Propagate but also attaches an error code.
func PropagateWithCode(cause error, code ErrorCode, msg string, vals ...interface{}) error { func PropagateWithCode(cause error, code ErrorCode, msg string, vals ...interface{}) error {
if cause == nil { if cause == nil {
// Allow calling PropagateWithCode without checking whether there is error // Allow calling PropagateWithCode without checking whether there is error
return nil //return nil
return create(nil, code, msg)
//&stacktrace{cause: nil, code: code, message: msg}
} }
return create(cause, code, msg, vals...) return create(cause, code, msg, vals...)
} }
...@@ -172,7 +174,7 @@ func GetCode(err error) ErrorCode { ...@@ -172,7 +174,7 @@ func GetCode(err error) ErrorCode {
if err, ok := err.(*stacktrace); ok { if err, ok := err.(*stacktrace); ok {
return err.code return err.code
} }
return NoCode return 1//NoCode
} }
type stacktrace struct { type stacktrace struct {
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment