package controller

import (
	"encoding/json"
	"fmt"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"net/http"
	"strconv"
	"time"
	"trading-go/model"
	"trading-go/response"
)

// UP upgrades chat HTTP requests to websocket connections with 1 KiB read
// and write buffers.
//
// NOTE(review): CheckOrigin accepts every request regardless of its Origin
// header, so a page on any site can open a chat socket from a visitor's
// browser. gorilla/websocket's default (a nil CheckOrigin) rejects
// cross-origin requests; consider restoring it or whitelisting known origins.
var UP = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin:     func(*http.Request) bool { return true },
}

// client is one user's live chat session.
type client struct {
	conn *websocket.Conn
	// pip queues messages addressed to this user; send drains it.
	pip chan *model.Message
	// uid is the user this session was registered for in Conns.
	uid uint
	// writeSem is a one-slot semaphore serialising writes to conn:
	// gorilla/websocket allows at most one concurrent writer, and both send
	// and heartBeat write to the same connection.
	writeSem chan struct{}
	// done is closed when the session ends so every goroutine serving it can
	// stop. It is closed only in shutdown, while connsSem is held.
	done chan struct{}
}

// writeMessage writes one frame to the session's connection while holding its
// write semaphore. The write deadline keeps a stalled peer from holding the
// semaphore (and thereby blocking send and heartBeat) forever.
func (c *client) writeMessage(messageType int, data []byte) error {
	const writeTimeout = 10 * time.Second
	c.writeSem <- struct{}{}
	defer func() { <-c.writeSem }()
	if err := c.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
		return err
	}
	return c.conn.WriteMessage(messageType, data)
}

// shutdown ends the session once: it wakes every goroutine waiting on done,
// closes the socket, and removes the session from Conns only if it is still
// the one registered for its uid (a newer session for that user is kept).
// Calling it again is a no-op.
func (c *client) shutdown() {
	lockConns()
	defer unlockConns()
	select {
	case <-c.done:
		return
	default:
	}
	close(c.done)
	// Best effort: the peer may already be gone; the session is over either way.
	_ = c.conn.Close()
	if Conns[c.uid] == c {
		delete(Conns, c.uid)
	}
}

// Conns maps a user id to that user's live chat session. It is read and
// written by the goroutines of every connection, so any access must hold
// connsSem (via lockConns/unlockConns).
var Conns = make(map[uint]*client)

// connsSem is a buffered channel of capacity one used as a mutex for Conns;
// unsynchronised concurrent map access is a fatal runtime error in Go.
var connsSem = make(chan struct{}, 1)

// lockConns acquires connsSem.
func lockConns() { connsSem <- struct{}{} }

// unlockConns releases connsSem.
func unlockConns() { <-connsSem }

// lookupClient returns the live session for uid, or nil if the user is not
// connected.
func lookupClient(uid uint) *client {
	lockConns()
	defer unlockConns()
	return Conns[uid]
}

// sessionFor returns the session registered for uid if it owns conn, or nil
// if that session has already been replaced by a reconnect or torn down.
func sessionFor(conn *websocket.Conn, uid uint) *client {
	lockConns()
	defer unlockConns()
	if c := Conns[uid]; c != nil && c.conn == conn {
		return c
	}
	return nil
}

// send delivers messages queued for the session (conn, uid) to its websocket.
// Each message is saved and JSON-encoded; message types 1 and 2 are then
// written as text frames and all other types are dropped. The loop stops when
// the session ends or a save/encode/write fails, and the session is shut down.
func send(conn *websocket.Conn, uid uint) {
	c := sessionFor(conn, uid)
	if c == nil {
		// The session was replaced or torn down before this goroutine ran.
		return
	}
	defer c.shutdown()
	for {
		var data *model.Message
		select {
		case <-c.done:
			return
		case data = <-c.pip:
		}
		// NOTE(review): reception already calls Save on this same message
		// before queueing it; confirm this second Save is an intended update
		// rather than a duplicate insert.
		if err := data.Save(); err != nil {
			fmt.Println(err.Error())
			return
		}
		msg, err := json.Marshal(data)
		if err != nil {
			fmt.Println(err.Error())
			return
		}
		// Only message types 1 and 2 are forwarded, both as text frames.
		if data.MsgType != 1 && data.MsgType != 2 {
			continue
		}
		if err := c.writeMessage(websocket.TextMessage, msg); err != nil {
			fmt.Println(err.Error())
			return
		}
	}
}

// reception reads JSON messages from the session (conn, uid), saves each one
// and, when the recipient (msg.To) is connected, queues it on the recipient's
// session. Messages for offline recipients are only saved. The loop stops on a
// read or save error, or when the session ends, and the session is shut down.
func reception(conn *websocket.Conn, uid uint) {
	c := sessionFor(conn, uid)
	if c == nil {
		// The session was replaced or torn down before this goroutine ran.
		return
	}
	defer c.shutdown()
	for {
		var msg model.Message
		if err := conn.ReadJSON(&msg); err != nil {
			fmt.Println(err.Error())
			return
		}
		if err := msg.Save(); err != nil {
			fmt.Println(err.Error())
			return
		}
		target := lookupClient(msg.To)
		if target == nil {
			// Recipient is not connected; the message has already been saved.
			continue
		}
		// Never block forever on a full queue: give up if either side ends.
		select {
		case target.pip <- &msg:
		case <-target.done:
			// Recipient disconnected meanwhile; the message is already saved.
		case <-c.done:
			return
		}
	}
}

// NOTE(review): if broadcast is revived, iterate Conns while holding connsSem
// and write through client.writeMessage; writing to conn.conn directly would
// race with send and heartBeat.
//func broadcast(data model.Message) {
//	for _, conn := range Conns {
//		msg, err := json.Marshal(data)
//		if err != nil {
//			fmt.Println(err.Error())
//			break
//		}
//		err = conn.conn.WriteMessage(data.MsgType, msg)
//		if err != nil {
//			fmt.Println(err.Error())
//			break
//		}
//	}
//}

// heartBeat writes a type-4 "alive" message to the session (conn, uid)
// immediately and then every 5 seconds, until a write fails or the session
// ends; the session is then shut down.
func heartBeat(conn *websocket.Conn, uid uint) {
	c := sessionFor(conn, uid)
	if c == nil {
		// The session was replaced or torn down before this goroutine ran.
		return
	}
	defer c.shutdown()
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		msg := model.Message{
			MsgType: 4,
			From:    0,
			To:      uid,
			Time:    uint(time.Now().Unix()),
			Content: "alive",
		}
		data, err := json.Marshal(msg)
		if err != nil {
			fmt.Println(err.Error())
			return
		}
		if err := c.writeMessage(websocket.TextMessage, data); err != nil {
			fmt.Println(err.Error())
			return
		}
		select {
		case <-c.done:
			return
		case <-ticker.C:
		}
	}
}

// Chat upgrades the request to a websocket and registers it as uid's chat
// session, replacing (and closing) any previous session of the same user.
// Three goroutines then serve the connection: send (queued outgoing messages),
// reception (incoming messages) and heartBeat (periodic "alive" frames).
//
// No HTTP response is written here: when Upgrade fails it has already replied
// with an HTTP error, and when it succeeds the connection has been hijacked,
// so w can no longer be written to.
func Chat(w http.ResponseWriter, rq *http.Request, uid uint) {
	conn, err := UP.Upgrade(w, rq, nil)
	if err != nil {
		fmt.Println(err.Error())
		return
	}

	c := &client{
		conn:     conn,
		pip:      make(chan *model.Message, 1024),
		uid:      uid,
		writeSem: make(chan struct{}, 1),
		done:     make(chan struct{}),
	}
	lockConns()
	previous := Conns[uid]
	Conns[uid] = c
	unlockConns()
	if previous != nil {
		// Only one session per user is tracked; close the old one so its
		// goroutines and socket do not leak.
		previous.shutdown()
	}

	go send(conn, uid)
	go reception(conn, uid)
	go heartBeat(conn, uid)
}

// LinkToServer opens a chat websocket for the user whose id is given in the
// "uid" query parameter. A missing, negative or otherwise malformed uid is
// rejected with 400 Bad Request before any upgrade is attempted.
//
// NOTE(review): uid is taken from the query string without authentication,
// so any caller can open a session (and receive messages) as any user.
//
// @Tags 聊天模块
// @Summary 与服务端进行websocket连接,请使用postman测试
// @Param uid query string true "user id"
// @Success 200 {object} response.Response
// @Router /chat [get]
func LinkToServer(c *gin.Context) {
	// ParseUint (not Atoi) so negative ids are rejected instead of wrapping
	// to huge uint values; IntSize matches the width of uint.
	uid, err := strconv.ParseUint(c.Query("uid"), 10, strconv.IntSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	Chat(c.Writer, c.Request, uint(uid))
}

// GetMsgFromPaged returns one page of the non-expired messages sent by the
// user given in the "uid" query parameter, as produced by
// model.Message.GetFrom. The page number and page size come from the path.
// Malformed or negative paging values and a malformed uid are rejected with
// 400 Bad Request; a lookup failure yields 500.
//
// @Tags 聊天模块
// @Summary 获取未过期且来源为特定用户的聊天记录
// @Success 200 {object} response.Response
// @Param page path int true "页数"
// @Param pageSize path int true "一页的大小"
// @Param uid query string true "用户id"
// @Router /chat/from/{page}/{pageSize} [get]
func GetMsgFromPaged(c *gin.Context) {
	// Each conversion is checked on its own; previously a bad page was
	// silently ignored because its error was overwritten by pageSize's.
	page, err := strconv.Atoi(c.Param("page"))
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	pageSize, err := strconv.Atoi(c.Param("pageSize"))
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	// A negative page or a non-positive page size never describes a valid page.
	if page < 0 || pageSize <= 0 {
		response.Fail(c.Writer, "page must be >= 0 and pageSize must be > 0", http.StatusBadRequest)
		return
	}
	uid, err := strconv.ParseUint(c.Query("uid"), 10, strconv.IntSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	rsp, err := model.Message{}.GetFrom(uint(uid), page, pageSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusInternalServerError)
		return
	}
	response.Success(c.Writer, "success", rsp)
}

// GetMsgToPaged returns one page of the non-expired messages addressed to the
// user given in the "uid" query parameter, as produced by
// model.Message.GetTo. The page number and page size come from the path.
// Malformed or negative paging values and a malformed uid are rejected with
// 400 Bad Request; a lookup failure yields 500.
//
// @Tags 聊天模块
// @Summary 获取未过期且目标为特定用户的聊天记录
// @Param page path int true "页数"
// @Param pageSize path int true "一页的大小"
// @Param uid query string true "用户id"
// @Success 200 {object} response.Response
// @Router /chat/to/{page}/{pageSize} [get]
func GetMsgToPaged(c *gin.Context) {
	// Each conversion is checked on its own; previously a bad page was
	// silently ignored because its error was overwritten by pageSize's.
	page, err := strconv.Atoi(c.Param("page"))
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	pageSize, err := strconv.Atoi(c.Param("pageSize"))
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	// A negative page or a non-positive page size never describes a valid page.
	if page < 0 || pageSize <= 0 {
		response.Fail(c.Writer, "page must be >= 0 and pageSize must be > 0", http.StatusBadRequest)
		return
	}
	uid, err := strconv.ParseUint(c.Query("uid"), 10, strconv.IntSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	rsp, err := model.Message{}.GetTo(uint(uid), page, pageSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusInternalServerError)
		return
	}
	response.Success(c.Writer, "success", rsp)
}

// GetMsg returns all chat messages between the users given by the "uid" and
// "target" query parameters, as produced by model.Message.GetMsg. Malformed
// ids are rejected with 400 Bad Request; a lookup failure yields 500.
//
// @Tags 聊天模块
// @Summary 获取与特定两个用户之间的所有聊天记录
// @Param uid query string true "用户id"
// @Param target query string true "对象id"
// @Success 200 {object} response.Response
// @Router /chat/msg [get]
func GetMsg(c *gin.Context) {
	// IntSize matches the width of uint, so the conversions below never truncate.
	uid, err := strconv.ParseUint(c.Query("uid"), 10, strconv.IntSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	target, err := strconv.ParseUint(c.Query("target"), 10, strconv.IntSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	var ms model.Message
	msgs, err := ms.GetMsg(uint(uid), uint(target))
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusInternalServerError)
		return
	}
	response.Success(c.Writer, "success", msgs)
}

// GetMsgLatest returns the most recent chat message for the "uid"/"target"
// pair given in the query string, as produced by model.Message.GetLatestMsg.
// Malformed ids are rejected with 400 Bad Request.
//
// NOTE(review): a lookup failure is also answered with 400 (the other chat
// handlers use 500); presumably this covers "no message yet" — confirm, and
// consider 404 for that case.
//
// @Tags 聊天模块
// @Summary 获取与特定用户相关的最后一条聊天记录
// @Param uid query string true "用户id"
// @Param target query string true "对象id"
// @Success 200 {object} response.Response
// @Router /chat/latest [get]
func GetMsgLatest(c *gin.Context) {
	// IntSize matches the width of uint, so the conversions below never truncate.
	uid, err := strconv.ParseUint(c.Query("uid"), 10, strconv.IntSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	target, err := strconv.ParseUint(c.Query("target"), 10, strconv.IntSize)
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	var ms model.Message
	latest, err := ms.GetLatestMsg(uint(uid), uint(target))
	if err != nil {
		response.Fail(c.Writer, err.Error(), http.StatusBadRequest)
		return
	}
	response.Success(c.Writer, "success", latest)
}