import os

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import roc_auc_score

from models_kit import lightgbm
from models_kit import xgboost

def topN_feature_importance_plot(model, clf, title="untitled", save_path='./mvp/plots/', topN=20):
    '''
    Plot the top-N feature importances of a trained classifier, save it as a PNG and show it.

    params:
        model: object exposing ``plot_importance(clf, max_num_features=...)``
            (presumably the lightgbm / xgboost module -- confirm with callers).
        clf: trained classifier/booster handed to ``model.plot_importance``.
        title: prefix of the saved file name.
        save_path: directory the PNG is written to; a trailing path separator
            is optional.
        topN: maximum number of features drawn.
    returns:
        path of the saved PNG file.
    '''
    # SimHei font, presumably so CJK labels render; with unicode_minus off, the ASCII
    # hyphen is used for minus signs so they still render with that font.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams['savefig.dpi'] = 226  # DPI of the saved image
    plt.rcParams['figure.dpi'] = 200  # on-screen figure resolution
    plt.figure(figsize=(10, 6))
    # NOTE(review): if model.plot_importance creates its own axes when no `ax` is given,
    # this figsize is not applied to the drawn plot -- confirm.
    model.plot_importance(clf, max_num_features=topN)
    plt.title("Feature Importances")

    # os.path.join inserts a separator when save_path lacks a trailing one, instead of
    # gluing the directory name onto the file name.
    path = os.path.join(save_path, title + "_featureImportance.png")
    plt.savefig(path)
    plt.show()
    return path


def topN_feature_importance_list(features, clf, topN=3):
    '''
    Return the names of the ``topN`` most important features, most important first.

    :param features: feature names, in the same order as the values returned by
        ``clf.feature_importance()``.
    :param clf: trained model exposing ``feature_importance()`` (such as a LightGBM Booster).
    :param topN: number of feature names to return (default 3).
    :return: list of at most ``topN`` feature names.
    '''
    importance = pd.DataFrame({
        'column': features,
        'importance': clf.feature_importance(),
    })
    ranked = importance.sort_values(by='importance', ascending=False)
    # Slice by topN; the previous hard-coded [:3] silently ignored the parameter.
    return ranked['column'].tolist()[:topN]


def model_selection(algorthm,clf,df_train,df_val,df_test,target,score,optimal_model,model_obj):
    # model matrix 存储不同模型指标的矩阵
    model_matrix_index = ['name', 'Params', 'trainAUC', 'validationAUC']
    model_matrix = pd.DataFrame(['NULL', 'NULL', roc_auc_score(df_train[target], df_train[score]),
                                 roc_auc_score(df_train[target], df_train[score])], index=model_matrix_index,
                                columns=['线上模型'])

    # 定义最优参指针
    pointer = 0
    # 遍历最优参组合
    for param in optimal_para:
        if algorthm == "lightGBM":
            train_auc, val_auc, lgbm = lightgbm.train_lgbm(lightgbm.params_lgb, df_train, df_val, model_obj.features,
                                                       adds_on=param, target=target)
        model_matrix = pd.concat([model_matrix,
                                  pd.DataFrame(['lightGBM', param, train_auc, val_auc], index=model_matrix_index,
                                               columns=[pointer])], axis=1)
        pointer += 1

    # 简单选取一下validation set auc 最高的 params
    best_params = model_matrix.T.sort_values(by='validationAUC', ascending=False).iloc[0, :].loc['Params']