package cn.quantgroup.handler;

import cn.quantgroup.model.MsgAgreement;
import cn.quantgroup.server.IStmsServer;
import cn.quantgroup.utils.Md5Utils;
import com.alibaba.fastjson.JSON;
import cn.quantgroup.model.DeviceChannelInfo;
import cn.quantgroup.server.CacheService;
import cn.quantgroup.store.WebSocketSession;
import cn.quantgroup.utils.CacheUtil;
import cn.quantgroup.utils.NetWorkUtils;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;

import java.util.Date;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;

/**
 * Netty WebSocket message handler.
 *
 * <p>Validates the initial HTTP upgrade request (query parameters {@code token} and
 * {@code channelId}), performs the WebSocket handshake, records the connected channel, and
 * then dispatches incoming WebSocket frames (close, ping, text, binary).
 *
 * @author qiding
 */
@Slf4j
@Component
public class WebsocketMessageHandler {

    private static final String USER = "user";
    /** Channel attribute holding a user identifier; read by {@link #getChannelByName(String)}. */
    private static final AttributeKey<String> USER_KEY = AttributeKey.valueOf(USER);
    /** Maximum WebSocket frame payload length passed to the handshaker factory (5 MiB). */
    private static final int MAX_FRAME_PAYLOAD_LENGTH = 5 * 1024 * 1024;

    private final CacheService cacheService;
    private final IStmsServer stmsServer;

    public WebsocketMessageHandler(CacheService cacheService, IStmsServer stmsServer) {
        this.cacheService = cacheService;
        this.stmsServer = stmsServer;
    }

    // NOTE(review): injected but not referenced anywhere in this class.
    @Autowired
    private StringRedisTemplate redisTemplate;

    /**
     * Handles the first HTTP request of a WebSocket connection (the upgrade handshake).
     *
     * <p>The request is rejected, and processing stops, when:
     * <ul>
     *   <li>it failed to decode (400) or is not a GET (403);</li>
     *   <li>{@code token} or {@code channelId} is missing or empty (400);</li>
     *   <li>the token is not recognised by {@code stmsServer} (403);</li>
     *   <li>the WebSocket version is unsupported.</li>
     * </ul>
     *
     * <p>On success it registers the channel relation, performs the handshake, stores the
     * handshaker, pushes a {@link DeviceChannelInfo} via {@code cacheService.getRedisUtil()},
     * and caches the channel in {@link CacheUtil#cacheChannel}.
     *
     * @param ctx     handler context of the connecting channel
     * @param request the HTTP upgrade request
     * @throws Exception propagated from collaborators
     */
    public void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest request) throws Exception {
        if (!this.isFullHttpRequest(ctx, request)) {
            return;
        }
        ConcurrentMap<String, String> paramMap = getUrlParams(request.uri());
        String token = paramMap.getOrDefault("token", "");
        String channelId = paramMap.getOrDefault("channelId", "");
        // Do not log the token: it is a credential.
        log.debug("接收到的参数 channelId：{}", channelId);
        if (token.isEmpty() || channelId.isEmpty()) {
            this.rejectHandshake(ctx, request, HttpResponseStatus.BAD_REQUEST);
            return;
        }
        String supplierCode = stmsServer.getStmsTokenInfo(token);
        if (null == supplierCode) {
            this.rejectHandshake(ctx, request, HttpResponseStatus.FORBIDDEN);
            return;
        }
        // Arguments: (ws location, sub-protocols, allow extensions, max frame payload length)
        WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(
                getWebSocketLocation(request), null, true, MAX_FRAME_PAYLOAD_LENGTH);
        WebSocketServerHandshaker handShaker = factory.newHandshaker(request);
        if (handShaker == null) {
            // Unsupported WebSocket version: nothing to register for this channel.
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            return;
        }
        String relationKey = Md5Utils.MD5Encode(channelId + "-" + supplierCode);
        online(relationKey, ctx.channel());
        handShaker.handshake(ctx.channel(), request);
        WebSocketSession.setChannelShaker(ctx.channel().id(), handShaker);
        SocketChannel channel = (SocketChannel) ctx.channel();
        // Persist the device/channel information.
        DeviceChannelInfo deviceChannelInfo = DeviceChannelInfo.builder()
                .channelId(channel.id().toString())
                .ip(NetWorkUtils.getHost())
                .port(channel.localAddress().getPort())
                .linkDate(new Date())
                .relationInfo(relationKey)
                .build();
        cacheService.getRedisUtil().pushObj(deviceChannelInfo);
        CacheUtil.cacheChannel.put(channel.id().toString(), channel);
    }

    /**
     * Dispatches a WebSocket frame received on an already-handshaken channel.
     *
     * <ul>
     *   <li>Close: completes the close handshake (or simply closes the channel when no
     *       handshaker was stored for it).</li>
     *   <li>Ping: answers with a Pong carrying the same payload.</li>
     *   <li>Text: ignores {@code "HeartBeat"}. Otherwise, if this channel is present in
     *       {@link CacheUtil#cacheChannel}, writes the text back to it; if not, publishes a
     *       {@link MsgAgreement} via {@code cacheService.push}.</li>
     *   <li>Binary: echoes the frame back.</li>
     * </ul>
     *
     * @param ctx   handler context of the channel the frame arrived on
     * @param frame the received frame
     */
    public void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) {
        if (frame instanceof CloseWebSocketFrame) {
            log.debug("收到关闭请求");
            WebSocketServerHandshaker handShaker = WebSocketSession.getChannelShaker(ctx.channel().id());
            if (handShaker == null) {
                // No handshaker stored for this channel; avoid an NPE and just close it.
                ctx.channel().close();
            } else {
                handShaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
            }
            return;
        }
        if (frame instanceof PingWebSocketFrame) {
            ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
            return;
        }
        if (frame instanceof TextWebSocketFrame) {
            String msgContent = ((TextWebSocketFrame) frame).text();
            if ("HeartBeat".equals(msgContent)) {
                log.debug("心跳消息");
                return;
            }
            log.debug("收到消息：\n{}", msgContent);
            String toChannelId = ctx.channel().id().toString();
            // Is the target channel connected to this server instance?
            Channel channel = CacheUtil.cacheChannel.get(toChannelId);
            if (null != channel) {
                channel.writeAndFlush(new TextWebSocketFrame(msgContent));
                return;
            }
            // Not local: publish so the instance holding the channel can deliver it.
            MsgAgreement msgAgreement = new MsgAgreement();
            msgAgreement.setToChannelId(toChannelId);
            msgAgreement.setContent(msgContent);
            cacheService.push(msgAgreement);
            return;
        }
        if (frame instanceof BinaryWebSocketFrame) {
            ctx.writeAndFlush(frame.retain());
        }
    }

    /**
     * Finds the channels in {@code ChannelHandlerPool.channelGroup} whose {@code user}
     * attribute equals {@code name}. Channels without that attribute are skipped.
     *
     * @param name user identifier to match
     * @return matching channels (possibly empty)
     */
    public List<Channel> getChannelByName(String name) {
        return ChannelHandlerPool.channelGroup.stream()
                .filter(channel -> {
                    String user = channel.attr(USER_KEY).get();
                    return user != null && user.equals(name);
                })
                .collect(Collectors.toList());
    }

    /**
     * Checks that the request is a decodable GET, as required for a WebSocket handshake.
     * If it is not, sends an error response that closes the connection.
     *
     * @return {@code true} if the handshake may proceed, {@code false} if it was rejected
     */
    private boolean isFullHttpRequest(ChannelHandlerContext ctx, FullHttpRequest request) {
        if (!request.decoderResult().isSuccess()) {
            log.error("非webSocket请求");
            this.rejectHandshake(ctx, request, HttpResponseStatus.BAD_REQUEST);
            return false;
        }
        if (!HttpMethod.GET.equals(request.method())) {
            log.error("非GET请求");
            this.rejectHandshake(ctx, request, HttpResponseStatus.FORBIDDEN);
            return false;
        }
        return true;
    }

    /**
     * Sends an error response for a rejected handshake.
     *
     * <p>For any non-OK status, {@link #sendResponse} attaches
     * {@link ChannelFutureListener#CLOSE}, so the connection is closed only after the
     * response has been flushed. An immediate {@code ctx.close()} could discard it.
     */
    private void rejectHandshake(ChannelHandlerContext ctx, FullHttpRequest request, HttpResponseStatus status) {
        this.sendResponse(ctx, request,
                new DefaultFullHttpResponse(request.protocolVersion(), status, ctx.alloc().buffer()));
    }

    /**
     * Builds the WebSocket location URL from the request's Host header.
     * Always uses the {@code ws:} scheme; a TLS deployment would need {@code wss:}.
     */
    private String getWebSocketLocation(FullHttpRequest request) {
        return "ws://" + request.headers().get(HttpHeaderNames.HOST) + "/websocket";
    }

    /**
     * Writes and flushes an HTTP response for the handshake request.
     *
     * <p>For non-OK statuses the status text is written as the body and Content-Length is set.
     * Keep-alive is honoured only for OK responses; otherwise the connection is closed once
     * the write completes.
     */
    private void sendResponse(ChannelHandlerContext ctx, FullHttpRequest req, FullHttpResponse resp) {
        HttpResponseStatus status = resp.status();
        boolean ok = HttpResponseStatus.OK.equals(status);
        if (!ok) {
            ByteBufUtil.writeUtf8(resp.content(), status.toString());
            // Headers belong on the response, not on the request.
            HttpUtil.setContentLength(resp, resp.content().readableBytes());
        }
        boolean keepAlive = HttpUtil.isKeepAlive(req) && ok;
        HttpUtil.setKeepAlive(resp, keepAlive);
        // Flush, otherwise the response stays queued in the outbound buffer.
        ChannelFuture future = ctx.writeAndFlush(resp);
        if (!keepAlive) {
            future.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Registers a user as online by storing the channel-id to relation-key mapping via
     * {@code cacheService.getRedisUtil().pushChannelRelation}.
     *
     * @param relationKey MD5 of {@code channelId + "-" + supplierCode}
     * @param channel     the connecting channel
     */
    private void online(String relationKey, Channel channel) {
        log.debug("relationKey：{}", relationKey);
        cacheService.getRedisUtil().pushChannelRelation(channel.id().toString(), relationKey);
    }

    /**
     * Parses the query string of a request URI into a map.
     *
     * <p>Only the part after the first {@code '?'} is considered. Pairs are split on
     * {@code '&'}, and each pair on its first {@code '='}, so values may themselves contain
     * {@code '='} (e.g. base64 padding). A parameter without {@code '='} maps to {@code ""};
     * empty pairs and empty keys are skipped. For duplicate keys, the last one wins. Values
     * are not URL-decoded.
     *
     * @param url request URI, e.g. {@code /websocket?token=abc&channelId=1}
     * @return mutable map of parameters (empty if there is no query string)
     */
    private static ConcurrentMap<String, String> getUrlParams(String url) {
        ConcurrentMap<String, String> map = new ConcurrentHashMap<>();
        if (url == null) {
            return map;
        }
        int queryStart = url.indexOf('?');
        if (queryStart < 0) {
            return map;
        }
        for (String pair : url.substring(queryStart + 1).split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            String[] kv = pair.split("=", 2);
            if (kv[0].isEmpty()) {
                continue;
            }
            map.put(kv[0], kv.length > 1 ? kv[1] : "");
        }
        return map;
    }
}
