123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- package controller
import (
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"

	"trading-go/model"
	"trading-go/response"
)
- var UP = websocket.Upgrader{
- ReadBufferSize: 1024,
- WriteBufferSize: 1024,
- }
- type client struct {
- conn *websocket.Conn
- pip chan *model.Message
- }
- var Conns map[uint]*client
- func init() {
- Conns = make(map[uint]*client)
- }
- // 发送消息
- func send(conn *websocket.Conn, uid uint) {
- pip := Conns[uid].pip
- defer func() {
- delete(Conns, uid)
- }()
- for {
- data := <-pip
- data.Save()
- msg, err := json.Marshal(data)
- if err != nil {
- fmt.Println("link closed")
- break
- }
- err = conn.WriteMessage(data.MsgType, msg)
- if err != nil {
- fmt.Println(err.Error())
- break
- }
- }
- }
- // 接收消息
- func reception(conn *websocket.Conn, uid uint) {
- defer func() {
- delete(Conns, uid)
- }()
- for {
- var msg model.Message
- err := conn.ReadJSON(&msg)
- if err != nil {
- fmt.Println(err.Error())
- break
- }
- msg.Save()
- if v, ok := Conns[msg.To]; ok {
- v.pip <- &msg
- } else {
- }
- }
- }
- func broadcast(data model.Message) {
- for _, conn := range Conns {
- msg, err := json.Marshal(data)
- if err != nil {
- fmt.Println(err.Error())
- break
- }
- err = conn.conn.WriteMessage(data.MsgType, msg)
- if err != nil {
- fmt.Println(err.Error())
- break
- }
- }
- }
- func heartBeat(uid uint) {
- defer func() {
- delete(Conns, uid)
- }()
- for {
- msg := model.Message{
- MsgType: 4,
- From: 0,
- To: uid,
- Time: uint(time.Now().Unix()),
- Content: "alive",
- }
- if v, ok := Conns[msg.To]; ok {
- v.pip <- &msg
- time.Sleep(time.Second)
- } else {
- }
- }
- }
- func Chat(w http.ResponseWriter, rq *http.Request, uid uint) {
- // 升级为websocket
- conn, err := UP.Upgrade(w, rq, nil)
- if err != nil {
- response.Fail(w, err.Error(), http.StatusUpgradeRequired)
- return
- }
- pip := make(chan *model.Message, 1024)
- client := client{
- conn: conn,
- pip: pip,
- }
- Conns[uid] = &client
- go send(conn, uid)
- go reception(conn, uid)
- go heartBeat(uid)
- response.Success(w, "success", nil)
- }
- // LinkToServer
- // @Tags 聊天模块
- // @Summary 与服务端进行websocket连接,请使用postman测试
- // @Success 200 {object} response.Response
- // @Router /chat [get]
- func LinkToServer(c *gin.Context) {
- uid, err := strconv.Atoi(c.Query("uid"))
- if err != nil {
- response.Fail(c.Writer, "failed", 500)
- return
- }
- Chat(c.Writer, c.Request, uint(uid))
- }
- // GetMsgFromPaged
- // @Tags 聊天模块
- // @Summary 获取未过期且来源为特定用户的聊天记录
- // @Success 200 {object} response.Response
- // @Param uid query string true "用户id"
- // @Router /chat/from [get]
- func GetMsgFromPaged(c *gin.Context) {
- id := c.Query("uid")
- uid, err := strconv.ParseUint(id, 10, 64)
- if err != nil {
- msg := err.Error()
- response.Fail(c.Writer, msg, 500)
- return
- }
- msgs, err := model.Message{}.GetFrom(uint(uid))
- if err != nil {
- response.Fail(c.Writer, err.Error(), 500)
- return
- }
- response.Success(c.Writer, "success", msgs)
- }
- // GetMsgToPaged
- // @Tags 聊天模块
- // @Summary 获取未过期且目标为特定用户的聊天记录
- // @Success 200 {object} response.Response
- // @Param uid query string true "用户id"
- // @Router /chat/to [get]
- func GetMsgToPaged(c *gin.Context) {
- id := c.Query("uid")
- uid, err := strconv.ParseUint(id, 10, 64)
- if err != nil {
- msg := err.Error()
- response.Fail(c.Writer, msg, 500)
- return
- }
- msgs, err := model.Message{}.GetTo(uint(uid))
- if err != nil {
- response.Fail(c.Writer, err.Error(), 500)
- return
- }
- response.Success(c.Writer, "success", msgs)
- }
|