chatcontroller.go 4.7 KB


  1. package controller
  import (
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"

	"trading-go/model"
	"trading-go/response"
)
  13. var UP = websocket.Upgrader{
  14. CheckOrigin: func(r *http.Request) bool {
  15. return true
  16. },
  17. ReadBufferSize: 1024,
  18. WriteBufferSize: 1024,
  19. }
  20. type client struct {
  21. conn *websocket.Conn
  22. pip chan *model.Message
  23. }
  24. var Conns map[uint]*client
  25. func init() {
  26. Conns = make(map[uint]*client)
  27. }
  28. // 发送消息
  29. func send(conn *websocket.Conn, uid uint) {
  30. pip := Conns[uid].pip
  31. defer func() {
  32. delete(Conns, uid)
  33. }()
  34. for {
  35. data := <-pip
  36. err := data.Save()
  37. if err != nil {
  38. fmt.Println(err.Error())
  39. break
  40. }
  41. msg, err := json.Marshal(data)
  42. if err != nil {
  43. fmt.Println("link closed")
  44. break
  45. }
  46. if data.MsgType == 1 {
  47. err = conn.WriteMessage(data.MsgType, msg)
  48. if err != nil {
  49. fmt.Println(err.Error())
  50. break
  51. }
  52. } else if data.MsgType == 2 {
  53. err = conn.WriteMessage(websocket.TextMessage, msg)
  54. if err != nil {
  55. fmt.Println(err.Error())
  56. break
  57. }
  58. }
  59. }
  60. }
  61. // 接收消息
  62. func reception(conn *websocket.Conn, uid uint) {
  63. defer func() {
  64. delete(Conns, uid)
  65. }()
  66. for {
  67. var msg model.Message
  68. err := conn.ReadJSON(&msg)
  69. if err != nil {
  70. fmt.Println(err.Error())
  71. break
  72. }
  73. err = msg.Save()
  74. if err != nil {
  75. fmt.Println(err.Error())
  76. break
  77. }
  78. if v, ok := Conns[msg.To]; ok {
  79. v.pip <- &msg
  80. } else {
  81. }
  82. }
  83. }
  84. func broadcast(data model.Message) {
  85. for _, conn := range Conns {
  86. msg, err := json.Marshal(data)
  87. if err != nil {
  88. fmt.Println(err.Error())
  89. break
  90. }
  91. err = conn.conn.WriteMessage(data.MsgType, msg)
  92. if err != nil {
  93. fmt.Println(err.Error())
  94. break
  95. }
  96. }
  97. }
  98. func heartBeat(conn *websocket.Conn, uid uint) {
  99. defer func() {
  100. delete(Conns, uid)
  101. }()
  102. for {
  103. msg := model.Message{
  104. MsgType: 4,
  105. From: 0,
  106. To: uid,
  107. Time: uint(time.Now().Unix()),
  108. Content: "alive",
  109. }
  110. data, err := json.Marshal(msg)
  111. if err != nil {
  112. fmt.Println(err.Error())
  113. break
  114. }
  115. err = conn.WriteMessage(websocket.TextMessage, data)
  116. if err != nil {
  117. fmt.Println(err.Error())
  118. break
  119. }
  120. time.Sleep(time.Second * 5)
  121. }
  122. }
  123. func Chat(w http.ResponseWriter, rq *http.Request, uid uint) {
  124. // 升级为websocket
  125. conn, err := UP.Upgrade(w, rq, nil)
  126. if err != nil {
  127. response.Fail(w, err.Error(), http.StatusUpgradeRequired)
  128. return
  129. }
  130. pip := make(chan *model.Message, 1024)
  131. client := client{
  132. conn: conn,
  133. pip: pip,
  134. }
  135. Conns[uid] = &client
  136. go send(conn, uid)
  137. go reception(conn, uid)
  138. go heartBeat(conn, uid)
  139. response.Success(w, "success", nil)
  140. }
  141. // LinkToServer
  142. // @Tags 聊天模块
  143. // @Summary 与服务端进行websocket连接,请使用postman测试
  144. // @Success 200 {object} response.Response
  145. // @Router /chat [get]
  146. func LinkToServer(c *gin.Context) {
  147. uid, err := strconv.Atoi(c.Query("uid"))
  148. if err != nil {
  149. response.Fail(c.Writer, "failed", 500)
  150. return
  151. }
  152. Chat(c.Writer, c.Request, uint(uid))
  153. }
  154. // GetMsgFromPaged
  155. // @Tags 聊天模块
  156. // @Summary 获取未过期且来源为特定用户的聊天记录
  157. // @Success 200 {object} response.Response
  158. // @Param page path int true "页数"
  159. // @Param pageSize path int true "一页的大小"
  160. // @Param uid query string true "用户id"
  161. // @Router /chat/from/{page}/{pageSize} [get]
  162. func GetMsgFromPaged(c *gin.Context) {
  163. id := c.Query("uid")
  164. p := c.Param("page")
  165. pS := c.Param("pageSize")
  166. page, err := strconv.Atoi(p)
  167. pageSize, err := strconv.Atoi(pS)
  168. if err != nil {
  169. response.Fail(c.Writer, err.Error(), 500)
  170. return
  171. }
  172. uid, err := strconv.ParseUint(id, 10, 64)
  173. if err != nil {
  174. msg := err.Error()
  175. response.Fail(c.Writer, msg, 500)
  176. return
  177. }
  178. rsp, err := model.Message{}.GetFrom(uint(uid), page, pageSize)
  179. if err != nil {
  180. response.Fail(c.Writer, err.Error(), 500)
  181. return
  182. }
  183. response.Success(c.Writer, "success", rsp)
  184. }
  185. // GetMsgToPaged
  186. // @Tags 聊天模块
  187. // @Summary 获取未过期且目标为特定用户的聊天记录
  188. // @Param page path int true "页数"
  189. // @Param pageSize path int true "一页的大小"
  190. // @Param uid query string true "用户id"
  191. // @Success 200 {object} response.Response
  192. // @Router /chat/to/{page}/{pageSize} [get]
  193. func GetMsgToPaged(c *gin.Context) {
  194. id := c.Query("uid")
  195. p := c.Param("page")
  196. pS := c.Param("pageSize")
  197. page, err := strconv.Atoi(p)
  198. pageSize, err := strconv.Atoi(pS)
  199. if err != nil {
  200. response.Fail(c.Writer, err.Error(), 500)
  201. return
  202. }
  203. uid, err := strconv.ParseUint(id, 10, 64)
  204. if err != nil {
  205. msg := err.Error()
  206. response.Fail(c.Writer, msg, 500)
  207. return
  208. }
  209. rsp, err := model.Message{}.GetTo(uint(uid), page, pageSize)
  210. if err != nil {
  211. response.Fail(c.Writer, err.Error(), 500)
  212. return
  213. }
  214. response.Success(c.Writer, "success", rsp)
  215. }