package cn.quantgroup.tech.db;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.boot.jdbc.DataSourceBuilder;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.env.Environment;
import org.springframework.core.type.AnnotationMetadata;

import javax.sql.DataSource;
import java.util.HashMap;
import java.util.Map;

/**
 * Registers a routing {@code dataSource} bean with a master and a slave target data source.
 * This class is loaded when the {@code EnableDynamicDataSource} annotation is used.
 *
 * <p>Connection settings are read from the {@link Environment} under the prefixes
 * {@code spring.datasource.} (master) and {@code slave.datasource.} (slave):
 * <ul>
 *   <li>{@code type} - pool implementation class, defaults to HikariCP;</li>
 *   <li>{@code driver-class-name} - JDBC driver class, defaults to {@code com.mysql.jdbc.Driver};</li>
 *   <li>{@code url}, {@code username}, {@code password} - required for both data sources.</li>
 * </ul>
 * The master data source is also the routing default.
 *
 * @author ag
 */
@Slf4j
public class DynamicDataSourceRegister implements ImportBeanDefinitionRegistrar, EnvironmentAware {

    /**
     * Default connection-pool implementation, used when {@code <prefix>type} is not configured.
     */
    private static final String DATASOURCE_TYPE_DEFAULT = "com.zaxxer.hikari.HikariDataSource";
    /**
     * Default JDBC driver class, used when {@code <prefix>driver-class-name} is not configured.
     */
    private static final String DATABASE_TYPE_DEFAULT = "com.mysql.jdbc.Driver";

    private static final String MASTER_PREFIX = "spring.datasource.";
    private static final String SLAVE_PREFIX = "slave.datasource.";

    private DataSource masterDataSource;
    private DataSource slaveDataSource;

    /**
     * Reads the master and slave settings from the environment and builds both data sources.
     *
     * @param environment the Spring environment holding the datasource properties
     * @throws IllegalStateException if a required property is missing, or if the configured pool
     *                               type cannot be loaded or is not a {@link DataSource}
     */
    @Override
    public void setEnvironment(Environment environment) {
        // 1. read the settings  2. build the data source  3. apply the settings to it
        masterDataSource = buildDataSource(initDataSource(environment, MASTER_PREFIX));

        slaveDataSource = buildDataSource(initDataSource(environment, SLAVE_PREFIX));
    }

    /**
     * Registers a synthetic {@code dataSource} bean of type {@code DynamicDataSource}. Its routing
     * targets are the master and slave data sources, and the master is the default target.
     * Both keys are also added to {@code DynamicDataSourceContextHolder.dataSourceIds}.
     */
    @Override
    public void registerBeanDefinitions(AnnotationMetadata annotationMetadata, BeanDefinitionRegistry beanDefinitionRegistry) {

        Map<Object, Object> targetDataSources = new HashMap<>(4);
        // master data source
        targetDataSources.put(DSType.MASTER, masterDataSource);
        DynamicDataSourceContextHolder.dataSourceIds.add(DSType.MASTER);
        // additional data sources
        targetDataSources.put(DSType.SLAVE, slaveDataSource);
        DynamicDataSourceContextHolder.dataSourceIds.add(DSType.SLAVE);

        // define the routing DynamicDataSource bean
        GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
        beanDefinition.setBeanClass(DynamicDataSource.class);
        beanDefinition.setSynthetic(true);
        MutablePropertyValues mpv = beanDefinition.getPropertyValues();
        mpv.addPropertyValue("defaultTargetDataSource", masterDataSource);
        mpv.addPropertyValue("targetDataSources", targetDataSources);
        beanDefinitionRegistry.registerBeanDefinition("dataSource", beanDefinition);

        log.info("Dynamic DataSource Registry Success");
    }

    /**
     * Collects the connection settings found under {@code prefix}.
     *
     * @param env    the environment to read from
     * @param prefix property prefix, e.g. {@code spring.datasource.}
     * @return a map with the non-null keys {@code type}, {@code driver-class-name}, {@code url},
     *         {@code username} and {@code password}
     * @throws IllegalStateException if {@code url}, {@code username} or {@code password} is missing
     */
    private Map<String, Object> initDataSource(Environment env, String prefix) {
        Map<String, Object> dsMap = new HashMap<>(8);
        // pool implementation class
        dsMap.put("type", env.getProperty(prefix + "type", DATASOURCE_TYPE_DEFAULT));
        // defaults to MySQL; no other driver is assumed
        dsMap.put("driver-class-name", env.getProperty(prefix + "driver-class-name", DATABASE_TYPE_DEFAULT));
        // getRequiredProperty fails fast and names the missing key, instead of the NPE that a null
        // value would otherwise cause later in buildDataSource
        dsMap.put("url", env.getRequiredProperty(prefix + "url"));
        dsMap.put("username", env.getRequiredProperty(prefix + "username"));
        dsMap.put("password", env.getRequiredProperty(prefix + "password"));
        return dsMap;
    }

    /**
     * Builds a {@link DataSource} from the settings produced by
     * {@link #initDataSource(Environment, String)}.
     *
     * @param dsMap connection settings; every key listed in {@code initDataSource} must be non-null
     * @return the built data source, never {@code null}
     * @throws IllegalStateException if the {@code type} class is not on the classpath or does not
     *                               implement {@link DataSource}
     */
    private DataSource buildDataSource(Map<String, Object> dsMap) {
        String type = dsMap.get("type").toString();

        Class<? extends DataSource> dataSourceType;
        try {
            // asSubclass verifies the type instead of performing an unchecked cast
            dataSourceType = Class.forName(type).asSubclass(DataSource.class);
        } catch (ClassNotFoundException e) {
            // returning null here used to defer the failure to a confusing routing error later on
            throw new IllegalStateException(
                    "DataSource type not found on the classpath (is the pool library, e.g. HikariCP, a dependency?): "
                            + type, e);
        } catch (ClassCastException e) {
            throw new IllegalStateException("Configured DataSource type does not implement javax.sql.DataSource: "
                    + type, e);
        }

        String driverClassName = dsMap.get("driver-class-name").toString();
        String url = dsMap.get("url").toString();
        String username = dsMap.get("username").toString();
        String password = dsMap.get("password").toString();

        return DataSourceBuilder.create().driverClassName(driverClassName).url(url)
                .username(username).password(password).type(dataSourceType)
                .build();
    }


}
