"""
Created on Thu Apr 18 11:32:06 2019

@author: wangjiahua
"""


import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns 


# Module-wide matplotlib defaults applied on import:
# - SimHei as the sans-serif font, presumably so Chinese labels render;
# - a plain ASCII hyphen for negative numbers (axes.unicode_minus=False),
#   commonly needed with CJK fonts that may lack the Unicode minus glyph;
# - high-resolution output for saved and on-screen figures.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
    'savefig.dpi': 226,  # resolution (DPI) used when saving figures
    'figure.dpi': 200,   # resolution (DPI) used when displaying figures
})



def plot_curve_singleCurve(dataset, x_label=None, y_label=None, table_tab=None,
                           save_path=None, figure_arrangement=11, fig_size=(4, 3),
                           fig_title='General Plot', fig_name='untitled',
                           fig_path=None):
    """Plot each column of ``dataset`` as its own curve, one subplot per column.

    ``figure_arrangement`` is a two-digit number ``RC`` describing a grid of
    ``R`` rows by ``C`` columns of subplots per figure (e.g. ``22`` -> 2x2).
    Columns fill the grid in order; when there are more columns than grid
    cells, further figures ("pages") are created and shown one after another.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Data to plot. Rows are sorted by index; the index values are rendered
        as string tick labels on the x axis. Each column becomes one subplot
        titled with the column name.
    x_label, y_label : str, optional
        Axis labels applied to every subplot (each one independently).
    figure_arrangement : int, default 11
        Grid layout encoded as ``rows * 10 + cols``.
    fig_size : tuple, default (4, 3)
        Size of each figure, passed to ``plt.figure``.
    table_tab, save_path, fig_title, fig_name, fig_path
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    int
        Always 1.

    Raises
    ------
    ValueError
        If ``figure_arrangement`` does not encode at least one row and one
        column.
    """
    n_rows, n_cols = divmod(figure_arrangement, 10)
    per_page = n_rows * n_cols
    if per_page <= 0:
        raise ValueError(
            'figure_arrangement must encode rows*10 + cols with rows, cols >= 1, '
            'got %r' % (figure_arrangement,))

    # Sort the rows once so that x tick labels and y values stay aligned.
    data = dataset.sort_index()
    x_ticks = data.index.astype(str)
    n_columns = data.shape[1]

    # One figure per page of ``per_page`` columns; the last page may be partial.
    for start in range(0, n_columns, per_page):
        plt.figure(figsize=fig_size)
        for offset, name in enumerate(data.columns[start:start + per_page]):
            # (nrows, ncols, index) form works for grids with more than 9 cells.
            axs = plt.subplot(n_rows, n_cols, offset + 1)
            # Select by position so the column (not a row) is plotted.
            axs.plot(x_ticks, data.iloc[:, start + offset])
            axs.set_title(name, fontsize=7)
            axs.tick_params(axis='both', labelsize=5)
            axs.grid()

            if x_label is not None:
                axs.set_xlabel(x_label, fontsize=5)
            if y_label is not None:
                axs.set_ylabel(y_label, fontsize=5)
        plt.tight_layout()
        plt.show()
    return 1
    



#fig,axs = plt.subplots(1,1,figsize=(16,9),linewidth=0.1)



#
#for fig_ith in range(len(df.columns)):
#    axs = plt.subplot(figure_arrangement * 10 + 1 + fig_ith)
#    axs.plot(df.index,df.iloc[fig_ith])
#    axs.set_title(col[])
#plt.tight_layout()

def plot_curve_multiCurve(dataset, x_label=None, y_label=None, table_tab=None,
                          save_path=None, figure_arrangement=11, fig_size=(4, 3),
                          fig_title='General Plot', fig_name='untitled',
                          fig_path=None):
    """Plot every column of ``dataset`` as a labelled curve on a single axes.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Data to plot. Rows are sorted by index; the index values are rendered
        as string tick labels on the x axis. Each column becomes one curve,
        labelled with the column name in the legend.
    x_label, y_label : str, optional
        Axis labels (each applied independently).
    fig_size : tuple, default (4, 3)
        Figure size, passed to ``plt.figure``.
    fig_title : str, default 'General Plot'
        Title of the axes.
    table_tab, save_path, figure_arrangement, fig_name, fig_path
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    int
        Always 1.
    """
    # Sort the rows once so that x tick labels and y values stay aligned.
    data = dataset.sort_index()
    x_ticks = data.index.astype(str)

    plt.figure(figsize=fig_size)
    axs = plt.subplot(111)
    for position, name in enumerate(data.columns):
        # Select by position so the column (not a row) is plotted.
        axs.plot(x_ticks, data.iloc[:, position], label=name)
    axs.set_title(fig_title, fontsize=7)
    axs.tick_params(axis='both', labelsize=5)
    axs.grid()

    if x_label is not None:
        axs.set_xlabel(x_label, fontsize=5)
    if y_label is not None:
        axs.set_ylabel(y_label, fontsize=5)
    # Only draw a legend when there is at least one labelled curve.
    if data.shape[1] > 0:
        axs.legend()
    plt.tight_layout()
    plt.show()
    return 1
    
def plot_curve_mingle():
    """Placeholder for a future plotting helper; not implemented yet.

    Returns
    -------
    int
        Always 1, like the other plotting helpers in this module.
    """
    status = 1
    return status
    
    
def density_chart(dataset, title):
    """Overlay a kernel-density estimate (``sns.kdeplot``) of every column.

    Parameters
    ----------
    dataset : pandas.DataFrame
        One density curve is drawn per column, labelled with the column name.
    title : str
        Title set on the current axes.

    Returns
    -------
    None
    """
    for name, values in dataset.items():
        sns.kdeplot(values, label=name)
    # Recent seaborn versions (0.11+) do not draw a legend for a plain
    # ``label=`` argument, so without this the column labels are never shown.
    if dataset.shape[1] > 0:
        plt.legend()
    plt.title(title)
    plt.show()


def uniVarChart():
    """Placeholder for a univariate chart helper; not implemented yet.

    Returns
    -------
    int
        Always 1.
    """
    result = 1
    return result
        
        
        
#        
#	    alpha = 0.98 / 4 * fig_ith + 0.01
#	    ax.set_title('%.3f' % alpha)
#	    t1 = np.arange(0.0, 1.0, 0.01)
#
#
#	    for n in [1, 2, 3, 4]:
#	        plt.plot(t1, t1 ** n, label="n=%d" % n)
#	    leg = plt.legend(loc='best', ncol=4, mode="expand", shadow=True)
#	    leg.get_frame().set_alpha(alpha)
#
#
#	# if this fig should be saved
#	if fig_path != None:
#		plt.savefig(fig_path + fig_name +'.png')
#	
#
#
##	for i in range(figure_arrangement%10):
##		plt.subplots(,figsize=fig_size,linewidth=0.1)
#
#	return 1