"""
Created on Thu Apr 18 11:32:06 2019

@author: Jason Wang
"""

import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

############# plot config ###############
# Module-wide matplotlib defaults, applied once when this file is imported.
plt.rcParams.update({
    # SimHei for sans-serif text (presumably so Chinese labels render).
    'font.sans-serif': ['SimHei'],
    # Draw minus signs with an ASCII hyphen instead of the Unicode minus glyph.
    'axes.unicode_minus': False,
    # Resolution used when saving images.
    'savefig.dpi': 226,
    # Resolution of figures on screen.
    'figure.dpi': 200,
})


def topN_feature_importance(model, clf, title="untitled", save_path='./mvp/plots/', topN=20):
    '''
    Plot the top-N feature importances of a fitted model, save the figure and show it.

    params:
        model     - module/object exposing ``plot_importance(clf, ax=..., max_num_features=...)``
                    (assumed to be xgboost or lightgbm -- confirm against callers)
        clf       - fitted model passed straight to ``model.plot_importance``
        title     - used only to build the output file name
        save_path - prefix concatenated as-is with the file name, so it should
                    end with a path separator
        topN      - maximum number of features to display
    returns:
        path of the saved PNG
    '''
    # Build the axes ourselves and hand them over. xgboost/lightgbm's plot_importance
    # opens a fresh figure when ``ax`` is None. Pre-creating a figure with plt.figure
    # would leave that figure blank and ignore the requested size.
    fig, ax = plt.subplots(figsize=(10, 6))
    model.plot_importance(clf, ax=ax, max_num_features=topN)
    ax.set_title("Feature Importances")

    path = save_path + title + " featureImportance.png"
    fig.savefig(path)
    plt.show()
    return path


def plot_table_list(datalist, auc, datalist_description=None, title='untitled', X_label=None, y_label=None,
                    tab_df_list=None, plot_tab=True,
                    tab_rows=None, saved_path=None):
    '''
    instructions : plot one line per series in ``datalist``, optionally with a value table below
    Params :
        datalist             - list of series-like objects with a single ``index`` and
                               one-dimensional ``values``; each is drawn against its
                               position 0..n-1
        auc                  - currently unused; kept so existing calls keep working
        datalist_description - legend label per entry of ``datalist`` (defaults to 0, 1, ...)
        title                - title of plot, also the saved file name ('untitled' as default)
        X_label              - X axis label of plot, skipped when None
        y_label              - y axis label of plot, skipped when None
        tab_df_list          - series to show in the table instead of ``datalist``
        plot_tab             - draw the table or not, default True; forced off when any
                               plotted series has a single point
        tab_rows             - row labels appended once per tabled series; defaults to
                               'index' and 'values'
        saved_path           - prefix the PNG path is built from (concatenated as-is with
                               ``title``); None means the figure is not saved
    Returns :
        1
    '''
    # Matches a 0 that starts a number (not preceded by a digit) right before the decimal
    # point. Removing it shortens "0.25" to ".25" while leaving "10.5" intact, which
    # str.replace('0.', '.') mangled into "1.5".
    leading_zero = re.compile(r'(?<!\d)0\.')

    def _compact(values):
        # Stringify, strip spaces and leading zeros so the table cells stay narrow.
        return (pd.Series(values).astype(str)
                .map(lambda s: leading_zero.sub('.', s.replace(' ', '')))
                .tolist())

    fig, axs = plt.subplots(1, 1, figsize=(13, 9), linewidth=0.1)

    if datalist_description is None:
        datalist_description = range(len(datalist))

    for i, data in enumerate(datalist):
        # Each entry must have exactly one index and one set of values.
        x = range(len(data.index))
        axs.plot(x, data.values, label=datalist_description[i])
        if len(x) == 1:
            # A one-column table is not drawn (it would be a degenerate layout).
            plot_tab = False

    # Table cells come from tab_df_list when given, otherwise from the plotted series.
    table_source = list(datalist if tab_df_list is None else tab_df_list)
    if plot_tab and table_source:
        table_rows = []
        tab_df = []
        for data in table_source:
            # Two table rows per series: its index labels, then its values.
            tab_df.append(_compact(data.index))
            tab_df.append(_compact(data.values))
            table_rows.extend(['index', 'values'] if tab_rows is None else tab_rows)
        # Column headers are counted from the tabled series (previously from a leaked
        # loop variable, which ignored tab_df_list and failed on an empty datalist).
        table_cols = range(len(table_source[-1].index))
        n_cols = len(table_cols)
        the_table = plt.table(cellText=tab_df,
                              rowLabels=table_rows,
                              colLabels=table_cols,
                              # max(...) avoids dividing by zero for a single column.
                              colWidths=[0.91 / max(n_cols - 1, 1)] * n_cols,
                              loc='bottom')
        # The table takes the place of the x tick labels.
        plt.xticks([])
        the_table.auto_set_font_size(False)
        the_table.set_fontsize(8)
    fig.subplots_adjust(bottom=0.2)
    plt.grid()
    if y_label is not None:
        plt.ylabel(y_label)
    if X_label is not None:
        plt.xlabel(X_label)
    plt.legend()
    plt.title(title)
    if saved_path is not None:
        plt.savefig(saved_path + title + ".png")
    plt.show()
    return 1


def plot_table_df(dataset, auc, title='untitled', X_label=None, y_label=None,
                  tab_df=None, plot_tab=True, saved_path=None):
    '''
    instructions : plot every column of a single DataFrame as a line, optionally with a value table
    Params :
        dataset    - DataFrame; each column is drawn against its row position (0..n-1)
        auc        - sequence aligned with dataset's columns; auc[i] is appended to the
                     legend label of column i
        title      - title of plot, also the saved file name ('untitled' as default)
        X_label    - X axis label of plot, skipped when None
        y_label    - y axis label of plot, skipped when None
        tab_df     - optional DataFrame whose columns/index/values are tabled instead of dataset's
        plot_tab   - plot table or not, default True
        saved_path - prefix the PNG path is built from (concatenated as-is with title);
                     None means the figure is not saved
    Returns :
        1
    '''
    # Matches a 0 that starts a number right before the decimal point. Removing it shortens
    # "0.25" to ".25" while keeping "10.5" intact (str.replace('0.', '.') gave "1.5").
    leading_zero = re.compile(r'(?<!\d)0\.')

    fig, axs = plt.subplots(1, 1, figsize=(13, 9), linewidth=0.1)

    table_rows = dataset.columns
    table_cols = pd.Series(dataset.index).astype(str).map(
        lambda s: leading_zero.sub('.', s.replace(' ', '')))

    # One line per column of the dataframe.
    x = range(len(table_cols))
    for i, column in enumerate(table_rows):
        axs.plot(x, dataset.iloc[:, i], label=str(column) + ' AUC: ' + str(auc[i]))

    if plot_tab:
        source = dataset
        if tab_df is not None:
            source = tab_df
            table_rows = tab_df.columns
            table_cols = tab_df.index
        # Row i of the table shows column i. The old code repeated column 1 in every
        # row and failed on single-column frames.
        cells = [list(source.iloc[:, i].values) for i in range(len(table_rows))]
        n_cols = len(table_cols)
        the_table = plt.table(cellText=cells,
                              rowLabels=table_rows,
                              colLabels=table_cols,
                              # max(...) avoids dividing by zero for a single column.
                              colWidths=[0.91 / max(n_cols - 1, 1)] * n_cols,
                              loc='bottom')
        # The table takes the place of the x tick labels.
        plt.xticks([])
        # Style the table only when one was drawn. This previously ran unconditionally
        # and raised for plot_tab=False.
        the_table.auto_set_font_size(False)
        the_table.set_fontsize(9)
    fig.subplots_adjust(bottom=0.2)
    plt.grid()
    if y_label is not None:
        plt.ylabel(y_label)
    if X_label is not None:
        plt.xlabel(X_label)
    plt.legend()
    plt.title(title)
    if saved_path is not None:
        plt.savefig(saved_path + title + ".png")
    plt.show()
    return 1






def plot_curve_singleCurve(dataset, x_label=None, y_label=None, table_tab=None,
                           save_path=None, figure_arrangement=11, fig_size=(4, 3),
                           fig_title='General Plot', fig_name='untitled',
                           fig_path=None):
    '''
    Plot each column of ``dataset`` in its own subplot, paging through a grid of subplots.

    Params:
        dataset            - DataFrame; every column is drawn against the sorted row index
                             (index labels shown as strings)
        x_label, y_label   - axis labels applied to every subplot, each skipped when None
        figure_arrangement - grid spec: tens part = rows, units digit = columns
                             (11 -> one subplot per page, 22 -> 2x2 grid, ...)
        fig_size           - size of each page's figure
        table_tab, save_path, fig_title, fig_name, fig_path
                           - currently unused; kept for call compatibility
    Returns:
        1
    Raises:
        ValueError - if figure_arrangement yields zero rows or columns
    '''
    # Explicit split. The old `a // 10 * a % 10` parsed as `((a // 10) * a) % 10`,
    # which is 0 (division by zero) whenever rows * cols >= 10.
    n_rows, n_cols = divmod(figure_arrangement, 10)
    per_page = n_rows * n_cols
    if n_rows < 1 or n_cols < 1:
        raise ValueError("figure_arrangement must encode at least 1 row and 1 column, got %r"
                         % (figure_arrangement,))

    # Sort rows once so the x labels and the plotted values stay aligned.
    ordered = dataset.sort_index()
    x = pd.Series(ordered.index).astype(str)
    columns = ordered.columns

    # True division before ceil: a partly filled last page still gets drawn.
    n_pages = int(np.ceil(len(columns) / per_page))
    for page in range(n_pages):
        start = page * per_page
        stop = min(start + per_page, len(columns))
        plt.figure(figsize=fig_size)
        for idx in range(start, stop):
            axs = plt.subplot(n_rows, n_cols, idx - start + 1)
            # Select the column by position (``.loc[label]`` would pick a row).
            axs.plot(x, ordered.iloc[:, idx])
            axs.set_title(columns[idx], fontsize=7)
            plt.xticks(fontsize=5)
            plt.yticks(fontsize=5)
            plt.grid()

            if x_label is not None:
                axs.set_xlabel(x_label, fontsize=5)
            if y_label is not None:
                axs.set_ylabel(y_label, fontsize=5)
        plt.tight_layout()
        plt.show()
    return 1


# fig,axs = plt.subplots(1,1,figsize=(16,9),linewidth=0.1)


#
# for fig_ith in range(len(df.columns)):
#    axs = plt.subplot(figure_arrangement * 10 + 1 + fig_ith)
#    axs.plot(df.index,df.iloc[fig_ith])
#    axs.set_title(col[])
# plt.tight_layout()

def plot_curve_multiCurve(dataset, x_label=None, y_label=None, table_tab=None,
                          save_path=None, figure_arrangement=11, fig_size=(4, 3),
                          fig_title='General Plot', fig_name='untitled',
                          fig_path=None):
    '''
    Overlay every column of ``dataset`` as a separate line on a single axes.

    Params:
        dataset          - DataFrame; every column is drawn against the sorted row index
                           (index labels shown as strings) and labelled with its name
        x_label, y_label - axis labels, each skipped when None
        fig_size         - figure size
        fig_title        - title of the axes
        table_tab, save_path, figure_arrangement, fig_name, fig_path
                         - currently unused; kept for call compatibility
    Returns:
        1
    '''
    # Sort rows once so the x labels and the plotted values stay aligned.
    ordered = dataset.sort_index()
    x = pd.Series(ordered.index).astype(str)
    plt.figure(figsize=fig_size)

    axs = plt.subplot(111)
    for idx, name in enumerate(ordered.columns):
        # Select the column by position (``.loc[label]`` would pick a row).
        axs.plot(x, ordered.iloc[:, idx], label=name)
    # Title the whole plot. It was previously the last column's name, taken from a
    # leaked loop variable.
    axs.set_title(fig_title, fontsize=7)
    plt.xticks(fontsize=5)
    plt.yticks(fontsize=5)
    plt.grid()

    if x_label is not None:
        axs.set_xlabel(x_label, fontsize=5)
    if y_label is not None:
        axs.set_ylabel(y_label, fontsize=5)
    plt.legend()
    plt.tight_layout()
    plt.show()
    return 1


def plot_curve_mingle():
    '''Unimplemented placeholder; draws nothing and always returns 1.'''
    placeholder_status = 1
    return placeholder_status


def density_chart(dataset, title):
    '''
    Overlay a kernel density estimate of every column of ``dataset`` and show the plot.

    Params:
        dataset - DataFrame; each column gets one KDE curve labelled with its name
        title   - plot title
    '''
    for col in dataset.columns:
        sns.kdeplot(dataset.loc[:, col], label=col)
    # Recent seaborn versions do not draw a legend from ``label=`` on their own, so
    # the per-column labels would otherwise never be displayed.
    if len(dataset.columns) > 0:
        plt.legend()
    plt.title(title)
    plt.show()

#
#	    alpha = 0.98 / 4 * fig_ith + 0.01
#	    ax.set_title('%.3f' % alpha)
#	    t1 = np.arange(0.0, 1.0, 0.01)
#
#
#	    for n in [1, 2, 3, 4]:
#	        plt.plot(t1, t1 ** n, label="n=%d" % n)
#	    leg = plt.legend(loc='best', ncol=4, mode="expand", shadow=True)
#	    leg.get_frame().set_alpha(alpha)
#
#
#	# if this fig should be saved
#	if fig_path != None:
#		plt.savefig(fig_path + fig_name +'.png')
#	
#
#
##	for i in range(figure_arrangement%10):
##		plt.subplots(,figsize=fig_size,linewidth=0.1)
#
#	return 1
