May1145 пре 1 година
родитељ
комит
878e3c23ef

+ 0 - 1
.idea/vcs.xml

@@ -2,6 +2,5 @@
 <project version="4">
   <component name="VcsDirectoryMappings">
     <mapping directory="" vcs="Git" />
-    <mapping directory="$PROJECT_DIR$/WeChatTrading" vcs="Git" />
   </component>
 </project>

+ 101 - 0
trading-go/controller/chatcontroller.go

@@ -0,0 +1,101 @@
+package controller
+
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
+	"net/http"
+	"strconv"
+	"trading-go/model"
+	"trading-go/response"
+)
+
+var UP = websocket.Upgrader{
+	ReadBufferSize:  1024,
+	WriteBufferSize: 1024,
+}
+
+type client struct {
+	conn *websocket.Conn
+	pip  chan *model.Message
+}
+
+var Conns map[int64]*client
+
+func init() {
+	Conns = make(map[int64]*client)
+}
+
+// 发送消息
+func send(conn *websocket.Conn, uid int64) {
+	pip := Conns[uid].pip
+	defer func() {
+		close(pip)
+		conn.Close()
+		delete(Conns, uid)
+	}()
+	for {
+		data := <-pip
+		msg, err := json.Marshal(data)
+		if err != nil {
+			fmt.Println("link closed")
+			break
+		}
+		err = conn.WriteMessage(1, msg)
+		if err != nil {
+			fmt.Println(err.Error())
+			continue
+		}
+	}
+}
+
+// 接收消息
+func reception(conn *websocket.Conn, uid int64) {
+	pip := Conns[uid].pip
+	defer func() {
+		close(pip)
+		conn.Close()
+		delete(Conns, uid)
+	}()
+	for {
+		var msg model.Message
+		err := conn.ReadJSON(&msg)
+		if err != nil {
+			fmt.Println(err.Error())
+			break
+		}
+		Conns[msg.To].pip <- &msg
+	}
+}
+
+func Chat(w http.ResponseWriter, rq *http.Request, uid int64) {
+	// 升级为websocket
+	conn, err := UP.Upgrade(w, rq, nil)
+
+	if err != nil {
+		response.Fail(w, err.Error(), http.StatusUpgradeRequired)
+		return
+	}
+
+	pip := make(chan *model.Message, 1024)
+	client := client{
+		conn: conn,
+		pip:  pip,
+	}
+	Conns[uid] = &client
+
+	go send(conn, uid)
+	go reception(conn, uid)
+
+	response.Success(w, "success", nil)
+}
+
+func LinkToServer(c *gin.Context) {
+	uid, err := strconv.Atoi(c.Query("uid"))
+	if err != nil {
+		response.Fail(c.Writer, "failed", 500)
+		return
+	}
+	Chat(c.Writer, c.Request, int64(uid))
+}

+ 35 - 6
trading-go/controller/usercontroller.go

@@ -3,6 +3,7 @@ package controller
 import (
 	"errors"
 	"github.com/gin-gonic/gin"
+	"strconv"
 	"trading-go/model"
 	"trading-go/response"
 	"trading-go/util"
@@ -11,16 +12,17 @@ import (
 // Register
 // @Tags 用户模块
 // @Summary 创建用户
-// @param uid formData string false "用户ID"
+// @param vid formData string false "微信ID"
 // @param name formData string false "用户名"
 // @param avatar formData string false "头像"
 // @Success 200 {string} json{"code","data","message"}
 // @Router /user/register [post]
 func Register(c *gin.Context) {
 	var user model.User
-	user.Uid = c.PostForm("uid")
+	user.Vid = c.PostForm("vid")
 	user.Name = c.PostForm("name")
 	user.Avatar = c.PostForm("avatar")
+	user.Uid = util.GenID()
 	err := user.Register()
 	if err != nil {
 		msg := err.Error()
@@ -34,12 +36,12 @@ func Register(c *gin.Context) {
 // Login
 // @Tags 用户模块
 // @Summary 登录
-// @param uid formData string false "用户ID"
+// @param vid formData string false "微信ID"
 // @Success 200 {string} json{"code","data","token"}
 // @Router /user/login [post]
 func Login(c *gin.Context) {
 	var user model.User
-	user.Uid = c.PostForm("uid")
+	user.Vid = c.PostForm("vid")
 	err, nu := user.Login()
 	if err != nil && err.Error() == "sql: no rows in result set" {
 		err = util.NoSuchUserError
@@ -76,7 +78,9 @@ func Login(c *gin.Context) {
 // @Router /user/modify [post]
 func Modify(c *gin.Context) {
 	var user model.User
-	user.Uid = c.PostForm("uid")
+	id := c.PostForm("uid")
+	uid, err := strconv.Atoi(id)
+	user.Uid = int64(uid)
 	user.Phone = c.PostForm("phone")
 	user.Avatar = c.PostForm("avatar")
 	user.Name = c.PostForm("name")
@@ -89,7 +93,7 @@ func Modify(c *gin.Context) {
 		}
 		return
 	}
-	err := user.Modify()
+	err = user.Modify()
 	if err != nil {
 		msg := err.Error()
 		response.Fail(c.Writer, msg, 500)
@@ -98,3 +102,28 @@ func Modify(c *gin.Context) {
 		response.Success(c.Writer, msg, nil)
 	}
 }
+
+// UserInfo
+// @Tags 用户模块
+// @Summary 获取用户信息
+// @Success 200 {string} json{"code","data"}
+// @Router /user/info [get]
+func UserInfo(c *gin.Context) {
+	var u model.User
+	token := c.GetHeader("Authorization")
+	t, claim, err := util.ParseToken(token)
+	if err != nil {
+		response.Fail(c.Writer, "failed to parse token", 500)
+		return
+	}
+	if !t.Valid {
+		response.Fail(c.Writer, "failed", 400)
+		return
+	}
+	u.Uid = claim.UserId
+	nu, err := u.Info()
+	if err != nil {
+		response.Fail(c.Writer, "failed", 500)
+	}
+	response.Success(c.Writer, "success", nu)
+}

+ 2 - 0
trading-go/go.mod

@@ -6,6 +6,7 @@ require (
 	github.com/KyleBanks/depth v1.2.1 // indirect
 	github.com/PuerkitoBio/purell v1.2.0 // indirect
 	github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
+	github.com/bwmarrin/snowflake v0.3.0 // indirect
 	github.com/bytedance/sonic v1.10.0-rc2 // indirect
 	github.com/cespare/xxhash/v2 v2.1.2 // indirect
 	github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect
@@ -27,6 +28,7 @@ require (
 	github.com/go-sql-driver/mysql v1.7.1 // indirect
 	github.com/goccy/go-json v0.10.2 // indirect
 	github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
+	github.com/gorilla/websocket v1.5.0 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
 	github.com/jmoiron/sqlx v1.3.5 // indirect
 	github.com/josharian/intern v1.0.0 // indirect

+ 4 - 0
trading-go/go.sum

@@ -44,6 +44,8 @@ github.com/PuerkitoBio/purell v1.2.0 h1:/Jdm5QfyM8zdlqT6WVZU4cfP23sot6CEHA4CS49E
 github.com/PuerkitoBio/purell v1.2.0/go.mod h1:OhLRTaaIzhvIyofkJfB24gokC7tM42Px5UhoT32THBk=
 github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M=
 github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
+github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0=
+github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE=
 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM=
 github.com/bytedance/sonic v1.10.0-rc2 h1:oDfRZ+4m6AYCOC0GFeOCeYqvBmucy1isvouS2K0cPzo=
@@ -175,6 +177,8 @@ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
 github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
 github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
+github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
+github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=

+ 8 - 0
trading-go/model/message.go

@@ -0,0 +1,8 @@
+package model
+
+type Message struct {
+	MsgType int   `json:"msgType"`
+	From    int64 `json:"from"`
+	To      int64 `json:"to"`
+	Content any   `json:"content"`
+}

+ 1 - 0
trading-go/model/relation.go

@@ -0,0 +1 @@
+package model

+ 15 - 4
trading-go/model/user.go

@@ -10,7 +10,8 @@ import (
 
 // User 用户信息模型
 type User struct {
-	Uid    string `json:"uid" db:"uid"`
+	Uid    int64  `json:"uid" db:"uid"`
+	Vid    string `json:"vid" db:"vid"`
 	Name   string `json:"name" db:"name"`
 	Avatar string `json:"avatar" db:"avatar"`
 	Email  string `json:"email" db:"email"`
@@ -19,7 +20,7 @@ type User struct {
 
 func (u User) Register() error {
 	db := common.DB
-	sqlStr := "INSERT INTO users(uid, name, avatar) values (:uid, :name, :avatar)"
+	sqlStr := "INSERT INTO users(uid, vid,name, avatar) values (:uid, :vid, :name, :avatar)"
 	exec, err := db.NamedExec(sqlStr, u)
 	if err != nil {
 		return err
@@ -36,11 +37,12 @@ func (u User) Register() error {
 
 func (u User) Login() (err error, nu User) {
 	db := common.DB
-	sqlStr := "SELECT * FROM users WHERE uid = ?"
-	err = db.Get(&nu, sqlStr, u.Uid)
+	sqlStr := "SELECT * FROM users WHERE vid = ?"
+	err = db.Get(&nu, sqlStr, u.Vid)
 	return
 }
 
+// SPhone 查找电话号码
 func (u User) SPhone() error {
 	var nu User
 	db := common.DB
@@ -55,6 +57,7 @@ func (u User) SPhone() error {
 	return util.PhoneBeUsed
 }
 
+// Modify 修改信息
 func (u User) Modify() error {
 	db := common.DB
 	sqlStr := "UPDATE users set name = ?,avatar = ?, phone = ? WHERE uid = ?"
@@ -71,3 +74,11 @@ func (u User) Modify() error {
 	}
 	return nil
 }
+
+// Info 获取用户信息
+func (u User) Info() (nu User, err error) {
+	db := common.DB
+	sqlStr := "SELECT * FROM users WHERE uid = ?"
+	err = db.Get(&nu, sqlStr, u.Uid)
+	return
+}

+ 2 - 0
trading-go/routine/routine.go

@@ -18,10 +18,12 @@ func GetRoutine() *gin.Engine {
 
 	user := r.Group("user")
 	{
+		user.GET("info", controller.UserInfo)
 		user.POST("modify", controller.Modify)
 		user.POST("login", controller.Login)
 		user.POST("register", controller.Register)
 	}
 
+	r.GET("chat", controller.LinkToServer)
 	return r
 }

+ 2 - 2
trading-go/util/jwt-go.go

@@ -8,11 +8,11 @@ import (
 var jwtKey = []byte("hello223rdwvwdfforl-'dandfdsafdafdsfsmay")
 
 type Claims struct {
-	UserId string
+	UserId int64
 	jwt.StandardClaims
 }
 
-func CreatToken(uid string) (string, error) {
+func CreatToken(uid int64) (string, error) {
 	expirationTime := time.Now().Add(31 * 24 * time.Hour)
 	claims := &Claims{
 		UserId: uid,

+ 38 - 0
trading-go/util/snowflake.go

@@ -0,0 +1,38 @@
+package util
+
+import (
+	"fmt"
+	"github.com/bwmarrin/snowflake"
+	"time"
+)
+
+var node *snowflake.Node
+
+func init() {
+	if err := Init("2021-12-03", 1); err != nil {
+		fmt.Println("Init() failed, err = ", err)
+		return
+	}
+}
+
+func Init(startTime string, machineID int64) (err error) {
+	var st time.Time
+	// 格式化 1月2号下午3时4分5秒  2006年
+	st, err = time.Parse("2006-01-02", startTime)
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	snowflake.Epoch = st.UnixNano() / 1e6
+	node, err = snowflake.NewNode(machineID)
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	return
+}
+
+func GenID() int64 {
+	return node.Generate().Int64()
+}