之前练手用的基于 Gin 的 Go Web 项目

发布时间: 2023-04-06 14:38:41  作者: 干炸小黄鱼

目录结构:

`-- demo
    |-- cmd
    |   |-- api.go
    |   `-- root.go
    |-- common
    |   `-- consts
    |       `-- consts.go
    |-- config
    |   `-- viper.go
    |-- config.toml
    |-- go.mod
    |-- go.sum
    |-- handler
    |   |-- jwt.go
    |   |-- h_hello.go
    |   `-- helper.go
    |-- main.go
    |-- model
    |   |-- db.go
    |   |-- m_user.go
    |   `-- helper.go
    |-- scripts
    |   |-- Readme.md
    |   |-- package_cms.sh
    |   `-- package_tar.sh
    |-- service
    |   `-- product_service.go
    `-- util

main.go

package main

import "demo/cmd"

// buildTime and gitHash are never assigned in Go code; presumably they are
// injected at link time (e.g. -ldflags "-X main.buildTime=... -X main.gitHash=...").
// They are forwarded to the cmd package, which prints them for --version.
var (
	buildTime string
	gitHash   string
)

// main hands control to the cobra command tree together with the build metadata.
func main() {
	cmd.Execute(buildTime, gitHash)
}

cmd/root.go

package cmd

import (
	"github.com/fatih/color"
	"github.com/spf13/cobra"
	"os"
	"demo/config"
	"demo/util/log"
	"runtime"
)

// rootCmd is the top-level "demo" command. Invoked without a subcommand it does
// nothing except print Go runtime and build information when -V/--version is set.
var rootCmd = &cobra.Command{
	Use:   "demo",
	Short: "demo",
	Long:  "demo",
	Run: func(_ *cobra.Command, _ []string) {
		if !isShowVersion {
			return
		}
		color.HiYellow("Golang Env: %s %s/%s", runtime.Version(), runtime.GOOS, runtime.GOARCH)
		color.Cyan("UTC build time: %s", buildTime)
		color.Yellow("Build from repo version: %s", gitHash)
	},
}

// Package-level state shared by the commands in this package.
var (
	// buildTime and gitHash are copied in by Execute from main.
	buildTime string
	gitHash   string
	// configFile holds the --config flag value.
	configFile string

	// verbose holds -v/--verbose; isShowVersion holds -V/--version.
	verbose       bool
	isShowVersion bool
)

// Execute stores the build metadata for --version output and runs the root
// command. If cobra reports an error it is logged and the process exits with
// status 1.
func Execute(bTime, gHash string) {
	buildTime, gitHash = bTime, gHash
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	log.Error(err)
	os.Exit(1)
}

// init registers config loading as a cobra initializer and declares the root
// command's flags.
func init() {
	cobra.OnInitialize(initFunc)
	rootCmd.Flags().BoolVarP(&isShowVersion, "version", "V", false, "show binary build information")
	rootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", true, "verbose")
	// The value is passed to viper.SetConfigName (via config.InitViper) and looked
	// up in the working directory, so it is a base name, not a file path.
	rootCmd.PersistentFlags().StringVar(&configFile, "config", "config", "config file name without extension, searched in the current directory")
}

// initFunc loads the configuration named by the --config flag.
func initFunc() { config.InitViper(configFile) }

cmd/api.go

package cmd

import (
	"fmt"
	"pms/config"
	"pms/handler"
	"pms/model"
	"pms/util/log"
	"time"

	"github.com/gin-contrib/sessions"
	"github.com/gin-contrib/sessions/cookie"
	"github.com/gin-gonic/gin"
	"github.com/go-co-op/gocron"
	"github.com/spf13/cobra"
)

// apiCmd starts the HTTP API. It connects the databases, registers the cron
// job, then serves on 0.0.0.0:<app.port>. Any startup failure is fatal.
var apiCmd = &cobra.Command{
	Use:   "api",
	Short: "run asset crawler api",
	Long:  "",
	Run: func(_ *cobra.Command, _ []string) {
		listenAddr := fmt.Sprintf("0.0.0.0:%d", config.GetInt("app.port"))
		log.Debug("start demo api...")

		if dbErr := initDb(); dbErr != nil {
			log.Fatal("数据库初始化失败:", dbErr)
		}
		if cronErr := initCorn(); cronErr != nil {
			log.Fatal("定时任务注册失败:", cronErr)
		}
		if apiErr := runRESTFulAPI(listenAddr); apiErr != nil {
			log.Fatal("run pms api failed, err : ", apiErr)
		}
	},
}

// init attaches the api subcommand to the root command.
func init() { rootCmd.AddCommand(apiCmd) }

// runRESTFulAPI builds the gin engine, registers all routes and blocks serving
// HTTP on addr. It returns only when the server stops with an error.
//
// gin.New() installs no middleware at all, so gin.Recovery() is added
// explicitly: without it a panicking handler aborts the connection instead of
// answering with HTTP 500.
func runRESTFulAPI(addr string) error {
	engine := gin.New()
	engine.Use(gin.Recovery())
	registerGinHandler(engine)
	return engine.Run(addr)
}

// initDb opens every database connection the API needs, in order: LocalDb,
// then ProductDb. It stops at the first failure and returns that error,
// annotated with the name of the connection that failed.
//
// Connections are not closed here: Create keeps each handle in package-level
// state for later queries (LocalDb.Create assigns the model package's db).
func initDb() error {
	targets := []struct {
		name string
		conn model.DB
	}{
		{name: "LocalDb", conn: new(model.LocalDb)},
		{name: "ProductDb", conn: new(model.ProductDb)},
	}
	for _, target := range targets {
		if err := target.conn.Create(); err != nil {
			// Wrap rather than print to stdout: the caller already logs the
			// returned error, and the connection name now travels with it.
			return fmt.Errorf("%s 初始化连接失败: %w", target.name, err)
		}
	}
	return nil
}

func initCorn() error {
	timezone, _ := time.LoadLocation("Asia/Shanghai")
	s := gocron.NewScheduler(timezone)
	//每周五晚上10点准时打包
	timeStr := config.GetString("cron.tar_time_set")
	_, err := s.Every(1).Day().At(timeStr).Do(func() {
		go handler.ReleaseFromCorn()
	})
	if err != nil {
		fmt.Println("定时任务初始化失败", err)
	}
	s.StartAsync()
	return nil
}

// registerGinHandler mounts every route under /api/v1.
//
// All /api/v1 routes share the CORS/access-log middleware and a cookie-backed
// session store. Routes on the authenticated child group additionally run
// handler.JWTCheck.
func registerGinHandler(r *gin.Engine) {
	apiG := r.Group("/api/v1")
	apiG.Use(crossDomainMiddleware())

	// Session cookies are signed with this key. "demo" is publicly known, so set
	// app.session_secret in real deployments; the fallback keeps existing
	// configs (and existing cookies) working.
	sessionSecret := config.GetString("app.session_secret")
	if sessionSecret == "" {
		sessionSecret = "demo"
	}
	store := cookie.NewStore([]byte(sessionSecret))
	apiG.Use(sessions.Sessions("sessions", store))

	// Public routes.
	apiG.GET("hello", handler.Hello)
	apiG.POST("login", handler.Login)
	apiG.Static("images", handler.GetImagePath())

	// Authenticated routes. A child group is used instead of apiG.Use(...):
	// Use mutates apiG itself, which would silently put JWTCheck in front of any
	// public route registered on apiG afterwards.
	authG := apiG.Group("", handler.JWTCheck())
	{
		authG.GET("auth_hello", handler.Hello)
	}
}

// crossDomainMiddleware allows cross-origin requests from any origin and, once
// the rest of the chain has run, logs the request method, URI and latency.
func crossDomainMiddleware() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		begin := time.Now()
		ctx.Header("Access-Control-Allow-Origin", "*")
		ctx.Next()
		req := ctx.Request
		log.Infof("%s %s cost: %v.", req.Method, req.RequestURI, time.Since(begin))
	}
}

config/viper.go

package config

import (
	"github.com/spf13/viper"
	"demo/util/log"
)

// InitViper tells viper to look for a config file named configFile (no
// extension) in the current working directory and reads it. A read failure is
// only logged, so later Get* calls presumably return zero values.
func InitViper(configFile string) {
	viper.AddConfigPath(".")
	viper.SetConfigName(configFile)
	if err := viper.ReadInConfig(); err != nil {
		log.Error("application configuration'initialization is failed ", err)
	}
}

// GetString returns the configuration value at key as a string.
func GetString(key string) string { return viper.GetString(key) }

// GetInt returns the configuration value at key as an int.
func GetInt(key string) int { return viper.GetInt(key) }

// GetBool returns the configuration value at key as a bool.
func GetBool(key string) bool { return viper.GetBool(key) }

// GetStringMapString returns the configuration value at key as a string map.
func GetStringMapString(key string) map[string]string { return viper.GetStringMapString(key) }

// GetStringSlice returns the configuration value at key as a string slice.
func GetStringSlice(key string) []string { return viper.GetStringSlice(key) }

model/db.go

package model

import (
	"fmt"
	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/mysql"
	"github.com/sirupsen/logrus"
	"demo/config"
)

// Package-level connection handles.
var (
	// db is assigned by LocalDb.Create.
	db *gorm.DB
	// productDb is presumably assigned by ProductDb.Create (not shown here).
	productDb *gorm.DB
)

// DB is implemented by each database connection this package manages.
type DB interface {
	// Create opens the connection.
	Create() error
	// Close releases the connection; it reports nothing to the caller.
	Close()
}

// LocalDb manages the primary connection stored in db.
type LocalDb struct{}

// Create opens the MySQL connection described by the db.* config keys
// (user, host, password, port, dbname, sql_mode), stores it in the
// package-level db handle and turns on gorm's SQL logging.
func (l LocalDb) Create() error {
	dbUser := config.GetString("db.user")
	dbHost := config.GetString("db.host")
	dbPassword := config.GetString("db.password")
	dbPort := config.GetInt("db.port")
	dbName := config.GetString("db.dbname")
	dbSqlMode := config.GetString("db.sql_mode")

	// sql_mode is quoted because its value is a comma-separated list of modes.
	conn := fmt.Sprintf("%s:%s@(%s:%d)/%s?charset=utf8&parseTime=True&loc=Local&sql_mode='%s'", dbUser, dbPassword, dbHost, dbPort, dbName, dbSqlMode)
	mysqlDb, err := gorm.Open("mysql", conn)
	if err != nil {
		return err
	}
	db = mysqlDb
	db.LogMode(true)
	return nil
}

// Close closes the local connection if Create has opened one. A close error is
// logged, not returned.
func (l LocalDb) Close() {
	if db == nil {
		return
	}
	if err := db.Close(); err != nil {
		logrus.WithError(err).Error("close LocalDb db failed")
	}
}

handler/jwt.go

package handler

import (
	"errors"
	"github.com/gin-gonic/gin"
	"demo/model"
	"strings"
)

// tokenPrefix is the Authorization scheme (plus separating space) this API
// accepts; bearerLen is its length in bytes.
const (
	tokenPrefix = "Bearer "
	bearerLen   = len(tokenPrefix)
)

// JWTCheck returns gin middleware that requires an "Authorization: Bearer <jwt>"
// header. When the token verifies, the user id parsed from it is stored in the
// context under jwtAuthedUserIdKey and the chain continues; otherwise the
// request is aborted through handleErrorAuth.
func JWTCheck() gin.HandlerFunc {
	return func(c *gin.Context) {
		rawToken := c.GetHeader("Authorization")
		// Check the scheme itself, not just the length, so a header such as
		// "Basic ..." is rejected here instead of having its first bytes stripped
		// and the remainder handed to the JWT parser. Auth schemes are
		// case-insensitive (RFC 7235), hence EqualFold.
		if len(rawToken) < bearerLen || !strings.EqualFold(rawToken[:bearerLen], tokenPrefix) {
			handleErrorAuth(c, errors.New("未找到承载令牌"))
			return
		}

		token := strings.TrimSpace(rawToken[bearerLen:])
		uid, err := model.JwtParseUid(token)
		if handleErrorAuth(c, err) {
			return
		}

		c.Set(jwtAuthedUserIdKey, uid)
		c.Next()
	}
}

model/helper.go

package model

import (
	"errors"
	"fmt"
	"github.com/golang-jwt/jwt/v5"
	"github.com/sirupsen/logrus"
	"demo/config"
	"strconv"
	"time"
)

// jwtObj bundles a user (Password cleared by jwtGenerateToken) with its signed
// token and expiry; presumably it is the payload returned to clients on login.
type jwtObj struct {
	User
	// Token is the signed JWT string.
	Token string `json:"token"`
	// Expire is the instant the token stops being valid.
	Expire time.Time `json:"expire"`
	// ExpireTs is Expire as Unix seconds.
	ExpireTs int64 `json:"expire_ts"`
}

// jwtGenerateToken issues an HS256-signed JWT for user m and returns it together
// with a copy of the user whose Password field is blanked.
//
// Claims: ID = user id, Issuer = app.name, IssuedAt/NotBefore = now,
// ExpiresAt = now + app.jwt_expire_hour hours. The signing key is app.secret.
//
// m itself is not modified. On any error the result is nil: a nil user, an
// empty app.secret, or a signing failure.
func jwtGenerateToken(m *User) (*jwtObj, error) {
	if m == nil {
		return nil, errors.New("nil user, can not generate jwt")
	}
	appSecret := config.GetString("app.secret")
	if appSecret == "" {
		// An empty HMAC key makes tokens trivially forgeable; refuse to sign.
		return nil, errors.New("app.secret is empty, can not generate jwt")
	}

	// Work on a copy so the caller's User keeps its Password.
	u := *m
	u.Password = ""

	now := time.Now()
	expireTime := now.Add(time.Hour * time.Duration(config.GetInt("app.jwt_expire_hour")))
	claim := jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(expireTime),
		NotBefore: jwt.NewNumericDate(now),
		IssuedAt:  jwt.NewNumericDate(now),
		Issuer:    config.GetString("app.name"),
		ID:        fmt.Sprintf("%d", u.ID),
	}

	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim).SignedString([]byte(appSecret))
	if err != nil {
		// Log and return rather than logrus.Fatal: a signing failure should fail
		// this request, not terminate the whole server.
		logrus.WithError(err).Error("config is wrong, can not generate jwt")
		return nil, err
	}
	return &jwtObj{User: u, Token: tokenString, Expire: expireTime, ExpireTs: expireTime.Unix()}, nil
}

// JwtParseUid verifies an HMAC-signed JWT against app.secret and returns the
// user id stored in its ID claim.
//
// It returns a non-nil error, always with uid 0, when any of these holds:
//   - the token is empty;
//   - the token is malformed, expired or not yet valid;
//   - the token is signed with a non-HMAC method or has a bad signature;
//   - app.secret is empty;
//   - the ID claim is not an unsigned integer that fits in uint.
func JwtParseUid(tokenString string) (uint, error) {
	if tokenString == "" {
		return 0, errors.New("授权承载中未找到令牌")
	}
	token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		secret := config.GetString("app.secret")
		if secret == "" {
			// Never verify against an empty key: such tokens are trivially forgeable.
			return nil, errors.New("app.secret is empty")
		}
		return []byte(secret), nil
	})

	// jwt/v5 reports validation failures (malformed, expired, bad signature...)
	// through err, so the cause has to be classified here. Checking token.Valid
	// only after an early return on err, as before, made these branches dead.
	switch {
	case errors.Is(err, jwt.ErrTokenMalformed):
		return 0, errors.New("token格式错误,无法解析")
	case errors.Is(err, jwt.ErrTokenExpired), errors.Is(err, jwt.ErrTokenNotValidYet):
		return 0, errors.New("token时效性错误")
	case err != nil:
		return 0, fmt.Errorf("token 解析失败%w", err)
	case !token.Valid:
		return 0, errors.New("token非法")
	}

	claims, ok := token.Claims.(*jwt.RegisteredClaims)
	if !ok {
		return 0, errors.New("token claims解析失败")
	}
	// Parse with the platform's uint width so the conversion below cannot truncate.
	uid, err := strconv.ParseUint(claims.ID, 10, strconv.IntSize)
	if err != nil {
		return 0, err
	}
	return uint(uid), nil
}

// Pagination holds page-based paging parameters, tagged for both JSON and
// form binding.
type Pagination struct {
	PageSize uint `json:"page_size" form:"page_size"`
	Page     uint `json:"page" form:"page"`
}

// FixPageAndSize replaces unset (zero) paging values with defaults: 10 items
// per page and page 1 (pages are 1-based). Non-zero values are left untouched.
func (p *Pagination) FixPageAndSize() {
	const (
		defaultPageSize = 10
		firstPage       = 1
	)
	if p.Page == 0 {
		p.Page = firstPage
	}
	if p.PageSize == 0 {
		p.PageSize = defaultPageSize
	}
}

handler/helper.go

package handler

import (
	"bytes"
	"crypto/md5"
	"encoding/hex"
	"encoding/xml"
	"errors"
	"fmt"
	"github.com/gin-gonic/gin"
	"io/ioutil"
	"net/http"
	"os"
	"os/exec"
	"pms/model"
	"pms/util/log"
	"reflect"
	"regexp"
	"strconv"
)

// Application-level status codes written into the "status" field of JSON
// responses. The HTTP status of those responses is always 200.
const (
	// StatusOK marks a successful response.
	StatusOK        uint = 200
	// StatusAuthError marks an authentication failure (used by handleErrorAuth).
	StatusAuthError      = 300
	// StatusError marks any other failure (used by handleError).
	StatusError          = 400
)

// jsonError aborts the request with HTTP 200 and body {"status": status, "msg": msg}.
func jsonError(c *gin.Context, status uint, msg string) {
	c.AbortWithStatusJSON(http.StatusOK, gin.H{"status": status, "msg": msg})
}

// jsonData aborts the request with HTTP 200 and body {"status": StatusOK, "data": data}.
func jsonData(c *gin.Context, data interface{}) {
	body := gin.H{"data": data, "status": StatusOK}
	c.AbortWithStatusJSON(http.StatusOK, body)
}

// jsonPagination aborts the request with HTTP 200 and a paged-list body that
// carries the items, the total count and the page/page_size actually used.
func jsonPagination(c *gin.Context, list interface{}, total uint, p model.Pagination) {
	c.AbortWithStatusJSON(http.StatusOK, gin.H{
		"status":    StatusOK,
		"data":      list,
		"total":     total,
		"page":      p.Page,
		"page_size": p.PageSize,
	})
}

// jsonErrorAuth writes an authentication error; the response shape is exactly
// that of jsonError.
func jsonErrorAuth(c *gin.Context, status uint, msg string) {
	jsonError(c, status, msg)
}
// handleError writes err as a StatusError JSON response and reports true.
// A nil err is a no-op that reports false, so callers can write
// `if handleError(c, err) { return }`.
func handleError(c *gin.Context, err error) bool {
	if err == nil {
		return false
	}
	jsonError(c, StatusError, err.Error())
	return true
}

// handleErrorAuth is handleError for authentication failures: a non-nil err is
// written with StatusAuthError and true is returned; nil returns false.
func handleErrorAuth(c *gin.Context, err error) bool {
	if err == nil {
		return false
	}
	jsonErrorAuth(c, StatusAuthError, err.Error())
	return true
}