package cn.quantgroup.xyqb.aspect.fplock;

import cn.quantgroup.xyqb.exception.ResubmissionException;
import cn.quantgroup.xyqb.model.JsonResult;
import org.apache.commons.collections.FastHashMap;
import org.apache.commons.lang3.time.DateFormatUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

/**
 * First-parameter lock aspect.
 *
 * <p>Intercepts methods annotated with {@link FPLock} and throttles calls keyed by the
 * annotation's {@code uniqueName} plus the string form of the method's first argument:
 * <ul>
 *   <li>without {@code restrictions}: a Redis counter rejects a second call with the same key
 *       while the first one is running and for {@code MAX_TO_LIVE} seconds after it finishes;</li>
 *   <li>with {@code restrictions}: one Redis counter per restriction, bucketed by the current
 *       time unit, rejects calls once its limit is exceeded; the counters incremented by a call
 *       are rolled back when that call is rejected, throws, or returns a non-"0000"
 *       {@link JsonResult}.</li>
 * </ul>
 * Rejected calls throw {@link ResubmissionException}.
 *
 * Created by Miraculous on 15/11/10.
 */
@Aspect
@Component
@Order(value = Ordered.HIGHEST_PRECEDENCE)
public class FirstParamLockAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(FirstParamLockAspect.class);
    /**
     * Restriction wrappers keyed by {@code uniqueName + TimeUnit}, shared by all requests and
     * adjustable at runtime via {@link #setLimitation}. {@code computeIfAbsent} on a
     * ConcurrentHashMap guarantees exactly one wrapper per key.
     */
    private static final ConcurrentMap<String, FPRestrictionWraper> LOCK_PARAM = new ConcurrentHashMap<>();
    /** Highest simple-mode counter value that may still proceed; anything above is a resubmission. */
    private static final Long MAX_COUNTER = 1L;
    /** Seconds a simple-mode lock key lives after the call finishes or is rejected. */
    private static final Long MAX_TO_LIVE = 10L;

    @Autowired
    @Qualifier("stringRedisTemplate")
    private RedisTemplate<String, String> stringRedisTemplate;
    /* Configurable limits for custom restriction strategies (FPRestriction).
     * NOTE(review): neither field is read anywhere in this class at present. */
    @Value("${xyqb.fplock.limit.byhour:3}")
    private Integer limitByHour;                // per-hour limit
    @Value("${xyqb.fplock.limit.byday:5}")
    private Integer limitByDay;                 // per-day limit

    @Pointcut("@annotation(cn.quantgroup.xyqb.aspect.fplock.FPLock)")
    private void fplockPointCut() {

    }

    /**
     * Around advice for {@link FPLock}-annotated methods.
     *
     * <p>Methods invoked without arguments are not locked. Otherwise the first argument
     * (via {@code String.valueOf}, so {@code null} is tolerated) becomes part of the lock key.
     *
     * @param pjp the intercepted invocation
     * @return whatever the advised method returns
     * @throws ResubmissionException when the lock or a restriction rejects the call
     * @throws Throwable anything the advised method throws
     */
    @Around("fplockPointCut()")
    private Object preventDuplicateSubmit(ProceedingJoinPoint pjp) throws Throwable {
        Object[] args = pjp.getArgs();
        if (args == null || args.length == 0) {
            return pjp.proceed();
        }

        MethodSignature methodSignature = (MethodSignature) pjp.getSignature();
        Method method = methodSignature.getMethod();
        FPLock fpLock = method.getAnnotation(FPLock.class);
        String variable = String.valueOf(args[0]);

        if (fpLock.restrictions().length < 1) {
            return proceedWithSimpleLock(pjp, fpLock.uniqueName() + variable);
        }
        return proceedWithRestrictions(pjp, method.getName(), fpLock.uniqueName(), fpLock.restrictions(), variable);
    }

    /**
     * Simple mode: allow at most {@code MAX_COUNTER} holder of {@code lockName}.
     *
     * <p>The key has no TTL while the call runs; once it finishes (or a duplicate is rejected)
     * the key is set to expire after {@code MAX_TO_LIVE} seconds.
     * NOTE(review): if the process dies between the increment and the {@code finally}, the key is
     * left without a TTL and the lock is never released — consider a TTL right after increment.
     */
    private Object proceedWithSimpleLock(ProceedingJoinPoint pjp, String lockName) throws Throwable {
        Long ret = stringRedisTemplate.opsForValue().increment(lockName, 1L);
        if (MAX_COUNTER < ret) {
            stringRedisTemplate.expire(lockName, MAX_TO_LIVE, TimeUnit.SECONDS);
            throw new ResubmissionException();
        }
        try {
            return pjp.proceed();
        } finally {
            stringRedisTemplate.expire(lockName, MAX_TO_LIVE, TimeUnit.SECONDS);
        }
    }

    /**
     * Restriction mode: check every configured restriction, run the call, and roll back the
     * counters this invocation incremented if the call is rejected, throws, or reports a
     * business failure.
     */
    private Object proceedWithRestrictions(ProceedingJoinPoint pjp, String methodName, String uniqueName,
                                           FPRestriction[] restrictions, String variable) throws Throwable {
        FPRestrictionWraper[] restrictionWrapers = wrapFPRestrictions(uniqueName, restrictions);
        // Exact Redis keys incremented by THIS invocation. Rolling back these (rather than
        // re-deriving keys from shared wrapper state and the current time) avoids touching
        // other requests' counters or a newer time bucket.
        List<String> incrementedLocks = new ArrayList<>(restrictionWrapers.length);

        if (!preRestrictionsCheck(uniqueName, variable, restrictionWrapers, incrementedLocks)) {
            LOGGER.warn("接口首参数保护! api:{}, variable:{}", methodName, variable);
            postRestrictionsCheck(incrementedLocks);
            throw new ResubmissionException();
        }

        Object proceed;
        try {
            proceed = pjp.proceed();
        } catch (Throwable t) {
            // Business call failed with any Throwable (not only Exception): undo the counting.
            postRestrictionsCheck(incrementedLocks);
            throw t;
        }
        // instanceof is null-safe; a null return must not be mistaken for a failure.
        if (proceed instanceof JsonResult) {
            JsonResult result = (JsonResult) proceed;
            if (!"0000".equals(result.getCode()) || !"0000".equals(result.getBusinessCode())) {
                // Business operation reported failure: undo the counting.
                postRestrictionsCheck(incrementedLocks);
            }
        }
        return proceed;
    }

    /**
     * Registers (if not yet cached) the restriction wrappers declared on {@code method}'s
     * {@link FPLock} annotation. Methods without the annotation are ignored.
     */
    public void parseFpLockAnnotaion(Method method) {
        FPLock fpLock = method.getAnnotation(FPLock.class);
        if (fpLock == null) {
            return;
        }
        wrapFPRestrictions(fpLock.uniqueName(), fpLock.restrictions());
    }

    /**
     * Resets restriction parameters at runtime.
     *
     * @param key      {@code uniqueName + TimeUnit}, as used by {@code wrapFPRestrictions}
     * @param duration new TTL in the restriction's time unit; {@code null} leaves it unchanged
     * @param limit    new maximum count per bucket; {@code null} leaves it unchanged
     */
    public void setLimitation(String key, Integer duration, Integer limit) {
        FPRestrictionWraper restriction = LOCK_PARAM.get(key);
        if (restriction == null) {
            return;
        }
        if (duration != null) {
            restriction.duration(duration);
        }
        if (limit != null) {
            restriction.limit(limit);
        }
    }

    /**
     * Reads the restriction parameters registered under {@code key}.
     *
     * @return the wrapper, or {@code null} if none has been registered yet
     */
    public FPRestrictionWraper readLimitation(String key) {
        return LOCK_PARAM.get(key);
    }

    /**
     * Wraps the annotation restrictions, reusing the cached wrapper per
     * {@code uniqueName + TimeUnit} so runtime changes via {@link #setLimitation} take effect.
     *
     * @param uniqueName   the {@link FPLock#uniqueName()} value
     * @param restrictions the declared restrictions
     * @return one wrapper per restriction, in declaration order
     */
    private FPRestrictionWraper[] wrapFPRestrictions(String uniqueName, FPRestriction[] restrictions) {
        FPRestrictionWraper[] wrapers = new FPRestrictionWraper[restrictions.length];
        for (int i = 0; i < restrictions.length; i++) {
            FPRestriction restriction = restrictions[i];
            String key = uniqueName + restriction.type().toString();
            // Atomic create-if-absent: concurrent first calls get the same wrapper instance.
            wrapers[i] = LOCK_PARAM.computeIfAbsent(key, k -> new FPRestrictionWraper(restriction));
        }
        return wrapers;
    }

    /**
     * Pre-check: increments each restriction's counter and verifies it is within the limit.
     *
     * <p>Stops at the first restriction whose counter exceeds its limit. Every key that was
     * incremented (including the one that exceeded) is appended to {@code incrementedLocks}
     * so the caller can roll back exactly those counters.
     *
     * @param uniqueName       the {@link FPLock#uniqueName()} value
     * @param variable         string form of the first method argument
     * @param restrictions     restrictions to check, in order
     * @param incrementedLocks output: Redis keys incremented by this call
     * @return {@code true} if every restriction passed
     */
    private boolean preRestrictionsCheck(String uniqueName, String variable, FPRestrictionWraper[] restrictions,
                                         List<String> incrementedLocks) {
        for (FPRestrictionWraper restriction : restrictions) {
            TimeUnit timeUnit = restriction.timeUnit();
            int duration = restriction.duration();
            int limit = restriction.limit();
            String lockName = uniqueName + restriction.getLockKey() + ":" + variable;
            // Number of operations recorded for this key in the current bucket.
            Long action = stringRedisTemplate.opsForValue().increment(lockName, 1L);
            incrementedLocks.add(lockName);
            restriction.setProceed(true);
            if (action > limit) {
                // Limit exceeded: this operation is not allowed.
                return false;
            }
            restriction.setSuccess(true);
            // This restriction passed; (re)arm the expiry while waiting for the next request.
            stringRedisTemplate.expire(lockName, duration, timeUnit);
        }
        return true;
    }

    /**
     * Rollback: decrements every counter in {@code incrementedLocks} once. Used when the call
     * was rejected or the business operation failed after counting.
     */
    private void postRestrictionsCheck(List<String> incrementedLocks) {
        for (String lockName : incrementedLocks) {
            stringRedisTemplate.opsForValue().increment(lockName, -1L);
        }
    }

    /**
     * Builds the time-bucket part of a Redis lock name from the current time: the unit name
     * followed by the current time formatted down to that unit (e.g. {@code HOURS2016010112}).
     * Units other than DAYS/HOURS/MINUTES/SECONDS use the bare unit name.
     */
    private String getLockKeyByTimeUnit(TimeUnit timeUnit) {
        Calendar calendar = Calendar.getInstance();
        String unit = timeUnit.toString();
        switch (timeUnit) {
            case DAYS:
                return unit + DateFormatUtils.format(calendar, "yyyyMMdd");
            case HOURS:
                return unit + DateFormatUtils.format(calendar, "yyyyMMddHH");
            case MINUTES:
                return unit + DateFormatUtils.format(calendar, "yyyyMMddHHmm");
            case SECONDS:
                return unit + DateFormatUtils.format(calendar, "yyyyMMddHHmmss");
            default:
                return unit;
        }
    }

    /**
     * Runtime wrapper around an {@link FPRestriction} annotation, holding its limit and
     * duration so they can be changed via {@link FirstParamLockAspect#setLimitation}.
     *
     * <p>Instances are cached in {@code LOCK_PARAM} and shared by all requests; {@code limit}
     * and {@code duration} are volatile so runtime changes are visible to request threads.
     * {@code isSuccess}/{@code isProceed} are set by the pre-check but never reset, so they
     * do not describe any single request; the rollback logic does not rely on them.
     */
    public class FPRestrictionWraper {
        private final FPRestriction restriction;
        private volatile int limit;
        private volatile int duration;
        private final TimeUnit timeUnit;
        private boolean isSuccess;
        private boolean isProceed;

        public FPRestrictionWraper(FPRestriction restriction) {
            this.restriction = restriction;
            this.timeUnit = restriction.type();
            this.limit = restriction.limit();
            this.duration = restriction.duration();
        }

        public TimeUnit timeUnit() {
            return this.timeUnit;
        }

        public int duration() {
            return this.duration;
        }

        public void duration(int duration) {
            this.duration = duration;
        }

        public int limit() {
            return this.limit;
        }

        public void limit(int limit) {
            this.limit = limit;
        }

        public boolean isSuccess() {
            return isSuccess;
        }

        public void setSuccess(boolean success) {
            isSuccess = success;
        }

        public boolean isProceed() {
            return isProceed;
        }

        public void setProceed(boolean proceed) {
            isProceed = proceed;
        }

        /** Time-bucket part of the Redis lock name, computed from the current time. */
        public String getLockKey() {
            return getLockKeyByTimeUnit(timeUnit());
        }

        public String desc() {
            return "timeUnit:" + timeUnit + ", duration:" + this.duration + ", limit:" + limit;
        }
    }
}
