package cn.quantgroup.big.stms.common.context;

import cn.quantgroup.big.stms.common.annotation.EnumKey;
import cn.quantgroup.big.stms.common.enums.BaseEnum;
import cn.quantgroup.big.stms.common.exception.BizException;
import cn.quantgroup.big.stms.common.result.ResultCode;
import org.reflections.Reflections;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
 * Registry that maps lookup names to {@link BaseEnum} implementation classes.
 *
 * <p>When this class is initialized, every subtype of {@link BaseEnum} that {@link Reflections}
 * finds under the {@code cn.quantgroup.big.stms} package is registered under two keys:
 * <ul>
 *   <li>its fully-qualified class name, and</li>
 *   <li>a short key: the value of its {@link EnumKey} annotation if present, otherwise its simple
 *       class name lower-cased with {@link Locale#ROOT}.</li>
 * </ul>
 * If two classes produce the same short key, the one registered first keeps it. Registration
 * follows the iteration order of the set returned by Reflections. A warning is logged for the
 * class that was ignored.
 *
 * <p>The registry is built once in the static initializer and is read-only afterwards.
 */
@SuppressWarnings("rawtypes")
public class EnumClassFactory {
    private final static Logger logger = LoggerFactory.getLogger(EnumClassFactory.class);

    /** Lookup key (fully-qualified name or short key) to enum class; unmodifiable after class init. */
    private static final Map<String, Class<? extends BaseEnum>> enums;

    static {
        Reflections reflections = new Reflections("cn.quantgroup.big.stms");

        Set<Class<? extends BaseEnum>> enumsClasses = reflections.getSubTypesOf(BaseEnum.class);

        Map<String, Class<? extends BaseEnum>> registry = new HashMap<>();
        for (Class<? extends BaseEnum> it : enumsClasses) {
            // The fully-qualified name is always (re)registered for this class.
            registry.put(it.getName(), it);

            // The short key is first-come-first-served; surface collisions instead of hiding them.
            String key = shortKeyOf(it);
            Class<? extends BaseEnum> existing = registry.putIfAbsent(key, it);
            if (existing != null && existing != it) {
                logger.warn("Duplicate enum key '{}': keeping {}, ignoring {}",
                        key, existing.getName(), it.getName());
            }
        }
        enums = Collections.unmodifiableMap(registry);

        logger.info("枚举类已注册到枚举工厂 ");
    }

    /**
     * Computes the short lookup key for an enum class: its {@link EnumKey} value if annotated,
     * otherwise its simple name lower-cased in a locale-independent way.
     */
    private static String shortKeyOf(Class<? extends BaseEnum> clazz) {
        EnumKey enumKey = clazz.getAnnotation(EnumKey.class);
        if (null != enumKey) {
            return enumKey.value();
        }
        // Locale.ROOT keeps keys stable regardless of the JVM default locale (e.g. Turkish 'I').
        return clazz.getSimpleName().toLowerCase(Locale.ROOT);
    }

    /**
     * Looks up a registered enum class.
     *
     * @param name a fully-qualified class name, an {@link EnumKey} value, or a lower-cased simple
     *     class name
     * @return the registered enum class; never {@code null}
     * @throws BizException with {@link ResultCode#PARAM_ERROR} if nothing is registered under
     *     {@code name} (including when {@code name} is {@code null})
     */
    public static Class<? extends BaseEnum> getEnumClass(String name) {
        Class<? extends BaseEnum> clazz = enums.get(name);
        if (null == clazz) {
            throw new BizException("请求的枚举类不存在", ResultCode.PARAM_ERROR);
        }
        return clazz;
    }
}
