package controller

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"sync"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"

	"trading-go/model"
	"trading-go/response"
)

// UP is the websocket upgrader shared by all chat connections.
var UP = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}

// client is one connected chat user.
type client struct {
	conn *websocket.Conn
	// pip queues outgoing messages. Only send writes them to conn, because
	// gorilla/websocket allows at most one concurrent writer per connection.
	pip chan *model.Message
	// done is closed exactly once (by close) when the client is torn down,
	// which releases anyone waiting to queue into pip.
	done chan struct{}
	once sync.Once
}

// close tears the client down: it closes done and the websocket.
// It is safe to call repeatedly and from several goroutines.
func (c *client) close() {
	c.once.Do(func() {
		close(c.done)
		_ = c.conn.Close() // best-effort; the peer may already be gone
	})
}

// deliver queues msg for c, blocking while c's queue is full. Once c has been
// torn down it gives up (dropping msg), so callers never block forever on a
// dead client.
func (c *client) deliver(msg *model.Message) {
	select {
	case c.pip <- msg:
	case <-c.done:
	}
}

// Conns maps a user id to that user's live connection. It is shared by the
// per-connection goroutines, so every access must hold connsMu.
var Conns = make(map[uint]*client)

// connsMu guards Conns.
var connsMu sync.Mutex

// register makes c the connection for uid. A previous connection for the same
// uid is closed, since it would otherwise stay open but never receive messages.
func register(uid uint, c *client) {
	connsMu.Lock()
	old := Conns[uid]
	Conns[uid] = c
	connsMu.Unlock()
	if old != nil && old != c {
		old.close()
	}
}

// unregister removes c from Conns (only if it is still the registered
// connection for uid, so a newer reconnect is not evicted) and tears it down.
func unregister(uid uint, c *client) {
	connsMu.Lock()
	if Conns[uid] == c {
		delete(Conns, uid)
	}
	connsMu.Unlock()
	c.close()
}

// lookup returns the live connection for uid, if any.
func lookup(uid uint) (*client, bool) {
	connsMu.Lock()
	defer connsMu.Unlock()
	c, ok := Conns[uid]
	return c, ok
}

// snapshot returns the currently registered connections, so callers can
// iterate (and possibly block) without holding connsMu.
func snapshot() []*client {
	connsMu.Lock()
	defer connsMu.Unlock()
	all := make([]*client, 0, len(Conns))
	for _, c := range Conns {
		all = append(all, c)
	}
	return all
}

// send (outgoing messages) drains c.pip and writes each message to the
// websocket as JSON, using the message's MsgType as the websocket frame type.
// It returns, tearing the client down, when the client is closed or a write
// fails.
func send(c *client, uid uint) {
	defer unregister(uid, c)
	for {
		select {
		case <-c.done:
			return
		case data := <-c.pip:
			msg, err := json.Marshal(data)
			if err != nil {
				// An unencodable message is dropped; the link itself is still fine.
				fmt.Println(err.Error())
				continue
			}
			if err := c.conn.WriteMessage(data.MsgType, msg); err != nil {
				fmt.Println(err.Error())
				return
			}
		}
	}
}

// reception (incoming messages) reads JSON messages from the websocket and
// queues each one to the connection of its recipient (msg.To). Messages for
// users who are not connected are dropped. It returns, tearing the client
// down, on the first read error (including the peer disconnecting).
func reception(c *client, uid uint) {
	defer unregister(uid, c)
	for {
		var msg model.Message
		if err := c.conn.ReadJSON(&msg); err != nil {
			fmt.Println(err.Error())
			return
		}
		if target, ok := lookup(msg.To); ok {
			target.deliver(&msg)
		}
	}
}

// broadcast queues data to every connected client. Delivery goes through each
// client's queue rather than writing to the socket directly, so it never races
// with that client's send goroutine. A failing client does not stop delivery
// to the others.
func broadcast(data model.Message) {
	msg := &data // send only reads the message, so one copy can be shared
	for _, c := range snapshot() {
		c.deliver(msg)
	}
}

// Chat upgrades the request to a websocket and registers it as user uid's
// connection, replacing (and closing) any earlier connection for that uid.
// It starts one goroutine that writes queued messages and one that reads and
// routes incoming messages; both exit when the connection ends.
func Chat(w http.ResponseWriter, rq *http.Request, uid uint) {
	// Upgrade to websocket.
	conn, err := UP.Upgrade(w, rq, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error, so
		// writing another response here would be a duplicate write.
		fmt.Println(err.Error())
		return
	}

	c := &client{
		conn: conn,
		pip:  make(chan *model.Message, 1024),
		done: make(chan struct{}),
	}
	register(uid, c)

	go send(c, uid)
	go reception(c, uid)
	// No HTTP response is written here: after a successful upgrade the
	// connection has been hijacked and w can no longer be used.
}

// LinkToServer is the GET /chat handler: it reads the caller's user id from
// the "uid" query parameter and upgrades the request to a websocket chat
// connection (the tag below means "chat module"; the summary says to test
// with Postman).
// NOTE(review): uid is taken from the query string without authentication,
// so any client can connect as any user — confirm this is intended.
//
// LinkToServer
// @Tags 聊天模块
// @Summary 与服务端进行websocket连接,请使用postman测试
// @Success 200 {object} response.Response
// @Router /chat [get]
func LinkToServer(c *gin.Context) {
	// ParseUint rejects negative ids, which Atoi would accept and which
	// would wrap around when converted to uint.
	uid, err := strconv.ParseUint(c.Query("uid"), 10, 0)
	if err != nil {
		response.Fail(c.Writer, "failed", http.StatusBadRequest)
		return
	}
	Chat(c.Writer, c.Request, uint(uid))
}