Commit e674b465 authored by 李宇怀's avatar 李宇怀
Browse files

写了个很朴素的测试

parent 7b23656d
......@@ -14,6 +14,8 @@ const (
ErrShortLinkActive = 8 // 短链被中断使用
ErrShortLinkTime = 9 // 不在可用时间范围
ErrOriginEmpty = 10
BadRequest = 400
InternalError = 500
)
......@@ -14,9 +14,9 @@ func setupLinkController(r *gin.RouterGroup) {
// Implemented in controller package.
}
p := r.Group("/link")
p.POST("/create", controller.ParseTokenMidware(), lcw.Create)
p.POST("/create", lcw.Create)
p.POST("/delete", controller.ParseTokenMidware(), lcw.Delete)
p.GET("/getinfo", controller.ParseTokenMidware(), lcw.GetInfo)
p.GET("/getinfo", lcw.GetInfo)
p.POST("/update", controller.ParseTokenMidware(), lcw.Update)
p.GET("/getlist", controller.ParseTokenMidware(), lcw.GetList)
}
......
......@@ -14,12 +14,13 @@ func setupUserController(r *gin.RouterGroup) {
// Implemented in controller package.
}
p := r.Group("/user")
p.POST("/register", controller.ParseTokenMidware(), lcw.Register)
p.GET("/getveri", controller.ParseTokenMidware(), lcw.GetVeri)
p.POST("/register", lcw.Register)
p.GET("/getveri", lcw.GetVeri)
p.POST("/login", controller.ParseTokenMidware(), lcw.Login)
p.GET("/getinfo", controller.ParseTokenMidware(), lcw.GetInfo)
p.POST("/modifyinfo", controller.ParseTokenMidware(), lcw.ModifyInfo)
p.POST("/modifypwd", controller.ParseTokenMidware(), lcw.ModifyPwd)
p.POST("/logout", controller.ParseTokenMidware(), lcw.ModifyPwd)
}
type UserCtlWrapper struct { //Wrapper类隔离接口具体逻辑
......@@ -115,3 +116,14 @@ func (w *UserCtlWrapper) ModifyPwd(c *gin.Context) {
}
dto.ResponseSuccess(c, "modify password successfully")
}
// logout
func (w *UserCtlWrapper) Logout(c *gin.Context) {
err := w.ctl.Logout(c)
if err != nil {
dto.ResponseFail(c, err)
return
}
dto.ResponseSuccess(c, "logout successfully")
}
......@@ -7,6 +7,7 @@ require (
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/gin-contrib/sessions v0.0.5
github.com/gin-gonic/gin v1.9.1
github.com/go-sql-driver/mysql v1.7.0
github.com/google/uuid v1.3.1
github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177
github.com/sirupsen/logrus v1.9.3
......@@ -29,7 +30,6 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
......
......@@ -6,6 +6,9 @@ import (
"go-svc-tpl/internal/dao/model"
"go-svc-tpl/utils/stacktrace"
"crypto/md5"
"encoding/hex"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
"gopkg.in/guregu/null.v4"
......@@ -50,23 +53,23 @@ func (c *LinkController) Create(ctx *gin.Context, req *dto.LinkCreateReq) (*dto.
Active: null.BoolFrom(true),
OwnerID: userID,
}
err := dao.DB(ctx).Create(newLink).Error
if err != nil {
logrus.Fatal(err)
return nil, err
if newLink.Short == "" {
newLink.Short = GenerateShort(newLink.Origin)
}
var link model.Link
err = dao.DB(ctx).Where(&model.Link{Short: newLink.Short}).First(&link).Error
err := dao.DB(ctx).Where(&model.Link{Short: newLink.Short}).First(&link).Error
if err != nil {
// 没找到相关记录 说明不是短连接重复
if err == gorm.ErrRecordNotFound {
logrus.Fatal(err)
return nil, err
}
return nil, stacktrace.PropagateWithCode(nil, dto.ErrShortLinkExist, "ErrShortLinkExist")
//logrus.Fatal(err)
//存到数据库中
err = dao.DB(ctx).Create(newLink).Error
if err != nil {
logrus.Error("Internal Error.")
return nil, stacktrace.PropagateWithCode(err, dto.InternalError, "InternalError")
}
// 没有错误 正常创建
return &dto.LinkCreateResp{
Short: newLink.Short,
Origin: newLink.Origin,
......@@ -75,26 +78,31 @@ func (c *LinkController) Create(ctx *gin.Context, req *dto.LinkCreateReq) (*dto.
EndTime: newLink.EndTime,
Active: newLink.Active,
}, nil
}
logrus.Error("Internal Error.")
return nil, stacktrace.PropagateWithCode(err, dto.InternalError, "InternalError")
}
//如果找到了的话
return nil, stacktrace.PropagateWithCode(err, dto.ErrShortLinkExist, "ErrShortLinkExist")
}
// delete
func (c *LinkController) Delete(ctx *gin.Context, req *dto.LinkDeleteReq) error {
userID := ctx.GetUint(model.USER_ID_KEY)
//userID := ctx.GetUint(model.USER_ID_KEY)
deleteLink := &model.Link{
Short: req.Short,
OwnerID: userID,
}
err := dao.DB(ctx).Delete(&deleteLink).Error
err := dao.DB(ctx).Where(&model.Link{Short: deleteLink.Short}).Delete(deleteLink).Error
if err != nil {
logrus.Fatal(err)
return stacktrace.PropagateWithCode(nil, dto.ErrNoShortLink, "ErrNoShortLink")
return err
}
var link model.Link
err = dao.DB(ctx).First(&link, "Short = ?", deleteLink.Short).Error
err = dao.DB(ctx).Where(&model.Link{Short: deleteLink.Short}).First(&link).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
logrus.Fatal(err)
return stacktrace.PropagateWithCode(nil, dto.ErrNoShortLink, "ErrNoShortLink")
return stacktrace.PropagateWithCode(err, dto.ErrNoShortLink, "ErrNoShortLink")
}
return err
}
......@@ -110,14 +118,14 @@ func (c *LinkController) GetInfo(ctx *gin.Context, req *dto.GetLinkInfoReq) (*dt
}
var link model.Link
err := dao.DB(ctx).First(&link, "Short = ?", getinfoLink.Short).Error
err := dao.DB(ctx).Where(&model.Link{Short: getinfoLink.Short}).First(&link).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
logrus.Fatal(err)
return nil, stacktrace.PropagateWithCode(nil, dto.ErrNoShortLink, "ErrNoShortLink")
}
return nil, err
}
return &dto.GetLinkInfoResp{
Short: link.Short,
Origin: link.Origin,
......@@ -142,15 +150,7 @@ func (c *LinkController) Update(ctx *gin.Context, req *dto.LinkUpdateReq) error
}
err := dao.DB(ctx).Updates(&updateLink).Error
if err != nil {
return err
}
var link model.Link
err = dao.DB(ctx).First(&link, "Short = ?", updateLink.Short).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
logrus.Fatal(err)
}
return err
}
return nil
......@@ -179,3 +179,15 @@ func (c *LinkController) GetList(ctx *gin.Context, req *dto.GetLinkListReq) (*dt
return resp, nil
}
// 生成短链接
// 输入长链接(字符串) 返回短链接(字符串)
func GenerateShort(origin string) string {
// 使用MD5哈希函数对长链接进行哈希计算
hash := md5.Sum([]byte(origin))
// 将哈希结果转换为16进制字符串
hashString := hex.EncodeToString(hash[:])
// 取哈希结果的前8个字符作为短链接
short := hashString[:8]
return short
}
package controller
import (
"go-svc-tpl/api/dto"
"context"
"go-svc-tpl/internal/dao"
"go-svc-tpl/internal/dao/model"
"go-svc-tpl/utils/stacktrace"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"gopkg.in/guregu/null.v4"
)
func setup() {
db := testInitDB()
var testDB = func(ctx context.Context) *dao.DBMS {
return &dao.DBMS{db.WithContext(ctx)}
}
dao.DB = testDB
gin.SetMode(gin.TestMode)
}
func TestMain(m *testing.M) {
setup()
m.Run()
teardown()
}
func teardown() {
dao.InitDB()
gin.SetMode(gin.ReleaseMode)
}
// TestLinkController_Create函数用于测试Create函数
func TestLinkController_Create(t *testing.T) {
var tests = []struct {
userID uint
input dto.LinkCreateReq
errorCode stacktrace.ErrorCode
}{
// // 测试用例1:正常创建短链接,不指定short字段,使用默认生成的short值
// {
// userID: 1,
// input: dto.LinkCreateReq{
// Short: "",
// Comment: "test link 1",
// Origin: "https://www.google.com",
// StartTime: null.TimeFrom(time.Now()),
// EndTime: null.TimeFrom(time.Now().Add(24 * time.Hour)),
// },
// errorCode: 0,
// },
// 测试用例2:正常创建短链接,指定short字段,使用自定义的short值
{
userID: 2,
input: dto.LinkCreateReq{
Short: "bing",
Comment: "test link 2",
Origin: "https://www.bing.com",
StartTime: null.TimeFrom(time.Now()),
EndTime: null.TimeFrom(time.Now().Add(24 * time.Hour)),
},
errorCode: 0,
},
// // 测试用例3:创建短链接失败,因为short字段已经存在
// {
// userID: 3,
// input: dto.LinkCreateReq{
// Short: "bing",
// Comment: "test link 3",
// Origin: "https://www.yahoo.com",
// StartTime: null.TimeFrom(time.Now()),
// EndTime: null.TimeFrom(time.Now().Add(24 * time.Hour)),
// },
// errorCode: dto.ErrShortLinkExist,
// },
}
var testLinkCtl ILinkController
testLinkCtl = new(LinkController)
for _, test := range tests {
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Set(model.USER_ID_KEY, test.userID)
_, err := testLinkCtl.Create(ctx, &test.input)
if test.errorCode != 0 {
assert.Equal(t, test.errorCode, stacktrace.GetCode(err))
continue
}
assert.NoError(t, err)
}
}
package controller
import (
"go-svc-tpl/internal/dao"
"go-svc-tpl/internal/dao/model"
"net/http"
"strings"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
......@@ -9,27 +13,99 @@ import (
// 解析和验证身份认证令牌
func ParseTokenMidware() gin.HandlerFunc {
return func(c *gin.Context) {
// 在此处添加解析和验证令牌的逻辑
tokenString := c.GetHeader("Authorization")
if tokenString == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing token"})
return func(ctx *gin.Context) {
// 获取authorization header
tokenString := ctx.GetHeader("Authorization")
// validate token formate
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
ctx.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "权限不足"})
ctx.Abort()
return
}
// 解析和验证令牌
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// 在此处添加密钥获取逻辑
// 通常从配置文件、环境变量等获取密钥
return []byte("your-secret-key"), nil
})
//提取token的有效部分("Bearer "共占7位)
tokenString = tokenString[7:]
token, claims, err := ParseToken(tokenString)
if err != nil || !token.Valid {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
ctx.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "权限不足"})
ctx.Abort()
return
}
// 验证通过后获取claim 中的userId
userId := claims.UserId
DB := dao.DB
var user model.User
DB(ctx).First(&user, userId)
// 用户不存在
if user.ID == 0 {
ctx.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "权限不足"})
ctx.Abort()
return
}
// 令牌验证通过,继续处理请求
c.Next()
// 用户存在将user的信息写入上下文,方便读取
ctx.Set("user", user)
ctx.Next()
}
}
// jwt加密密钥
var jwtKey = []byte("a_secret_crect")
// token的claim
type Claims struct {
UserId uint
jwt.StandardClaims
}
// 发放token
func ReleaseToken(user model.User) (string, error) {
//token的有效期
expirationTime := time.Now().Add(7 * 24 * time.Hour)
claims := &Claims{
//自定义字段
UserId: user.ID,
//标准字段
StandardClaims: jwt.StandardClaims{
//过期时间
ExpiresAt: expirationTime.Unix(),
//发放的时间
IssuedAt: time.Now().Unix(),
//发放者
Issuer: "127.0.0.1",
//主题
Subject: "user token",
},
}
//使用jwt密钥生成token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(jwtKey)
if err != nil {
return "", err
}
//返回token
return tokenString, nil
}
// 从tokenString中解析出claims并返回
func ParseToken(tokenString string) (*jwt.Token, *Claims, error) {
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (i interface{}, err error) {
return jwtKey, nil
})
return token, claims, err
}
......@@ -36,20 +36,24 @@ func (c *ServerController) Link(ctx *gin.Context, req *dto.ServerLinkReq) error
if err := dao.DB(ctx).Where(&model.Link{Short: req.Short}).First(&link).Error; err != nil {
if err == gorm.ErrRecordNotFound {
// 如果找不到对应的长链接,返回404错误
ctx.JSON(http.StatusNotFound, gin.H{"error": "Link not found"})
// 如果找不到对应的长链接,返回 短连接不存在(6) 错误
ctx.JSON(dto.ErrNoShortLink, gin.H{
"error": "Link not found",
})
} else {
// 如果发生其他错误,返回500错误
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Internal server error"})
ctx.JSON(dto.InternalError, gin.H{
"error": "Internal server error",
})
}
return err
}
// 如果找到了对应的长链接,进行重定向
ctx.Redirect(http.StatusMovedPermanently, link.Origin)
return nil
}
// 返回验证码的图片
func (c *ServerController) Veri(ctx *gin.Context, req *dto.ServerVeriReq) error {
// 获取验证码ID
id := req.Target
......
package controller
import (
"go-svc-tpl/internal/dao/model"
_ "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
var DB *gorm.DB
func testInitDB() *gorm.DB {
// 参考 https://github.com/go-sql-driver/mysql#dsn-data-source-name 获取详情
dsn := "root:1351508.@tcp(127.0.0.1:3306)/demo?charset=utf8mb4&parseTime=True&loc=Local"
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
panic("failed to connect database, err:" + err.Error())
}
//迁移
db.AutoMigrate(&model.Link{})
DB = db
return db
}
func GetDB() *gorm.DB {
return DB
}
......@@ -131,7 +131,9 @@ PropagateWithCode is similar to Propagate but also attaches an error code.
func PropagateWithCode(cause error, code ErrorCode, msg string, vals ...interface{}) error {
if cause == nil {
// Allow calling PropagateWithCode without checking whether there is error
return nil
//return nil
return create(nil, code, msg)
//&stacktrace{cause: nil, code: code, message: msg}
}
return create(cause, code, msg, vals...)
}
......@@ -172,7 +174,7 @@ func GetCode(err error) ErrorCode {
if err, ok := err.(*stacktrace); ok {
return err.code
}
return NoCode
return 1//NoCode
}
type stacktrace struct {
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment