# coding=utf-8
'''
自动写无人分析版的月监控报告.
'''

import numpy as np
import pandas as pd
from docx import Document
from docx.shared import Cm
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
from docx.enum.table import WD_TABLE_ALIGNMENT

from datetime import datetime
import os

import warnings

# Silence every warning process-wide (presumably to hide pandas/openpyxl noise).
# NOTE(review): this also hides genuinely useful warnings; consider narrowing
# the filter by category once the noisy sources are known.
warnings.simplefilter('ignore')

class AutoReportor:
    '''
    Automatically write the monthly model-monitoring report.

    Workflow:
        1. Load the raw monitoring info (AUC / PSI / VLM csv files and image lists).
        2. Pick out the important information.
        3. Summarize the information by rules (``summary``; not implemented yet).
        4. Show the details of every model (``show_detail``).
    '''

    # Maximum number of AUC table rows shown per model.
    MAX_AUC_ROWS = 15
    # Maximum number of abnormal PSI / VLM pictures shown per section.
    MAX_IMAGES = 5
    # Width (in cm) of every picture inserted into the document.
    IMAGE_WIDTH_CM = 10

    def __init__(self, excel_path='./model_score.xlsx', sheet_name='model',
                 is_pdf=False):
        '''
        Create an empty document and load all monitoring info.

        Args:
            excel_path: Excel file listing the monitored models; the sheet is
                read for its ``field_name`` and ``field_query`` columns.
            sheet_name: sheet of ``excel_path`` to read.
            is_pdf: when True, ``run`` also exports the saved report to PDF
                through Word (``win32com``).
        '''
        # Empty document with its title.
        self.doc = Document()
        self.doc.add_heading('模型监控报告', level=0)

        # Monitored models.
        self.field_info_df = pd.read_excel(excel_path, sheet_name=sheet_name)
        self.field_name_list = self.field_info_df.field_name.tolist()
        self.field_query_list = self.field_info_df.field_query.tolist()
        self.field_query_name_dict = dict(zip(self.field_query_list, self.field_name_list))

        # Directories holding the monitoring results.
        self.auc_path = './auc/'
        self.vlm_path = './vlm/'
        self.psi_path = './psi/'

        # Summary tables.
        self.auc_info_df = pd.read_csv(self.auc_path + 'info/auc_info.csv')
        self.psi_info_df = pd.read_csv(self.psi_path + 'info/psi_info.csv')
        self.vlm_info_df = pd.read_csv(self.vlm_path + 'info/vlm_info.csv')

        # Image key of each PSI row: '<field_query>-<group_name>'
        # (vectorized concatenation instead of a row-wise apply).
        self.psi_info_df['name'] = (self.psi_info_df['field_query'] + '-'
                                    + self.psi_info_df['group_name'])

        # Available picture file names.
        self.psi_image_name_list = self._list_png(self.psi_path + 'image/')
        self.vlm_image_name_list = self._list_png(self.vlm_path + 'image/image/')
        self.vlm_image_name_over3std_list = self._list_png(self.vlm_path + 'image/over3std/')
        self.vlm_image_name_trend_list = self._list_png(self.vlm_path + 'image/trend/')

        self.is_pdf = is_pdf

    @staticmethod
    def _list_png(dir_path):
        '''Return the names of the ``.png`` entries directly inside ``dir_path``.'''
        return [x for x in os.listdir(dir_path) if x.endswith('.png')]

    def get_image(self, image_name, type_='psi'):
        '''
        Return the path of the picture identified by ``image_name``.

        Args:
            image_name: image key without suffix, e.g. ``'<name>-全样本'``.
            type_: ``'psi'`` or ``'vlm'``.

        Returns:
            For ``'psi'``: ``<psi_path>image/<image_name>.png``.
            For ``'vlm'``: ``<image_name>-mean.png`` inside ``image/image/`` if
            listed there, else inside ``image/over3std/`` if listed there,
            otherwise inside ``image/trend/``.

        Raises:
            ValueError: if ``type_`` is neither ``'psi'`` nor ``'vlm'``.
        '''
        if type_ == 'psi':
            return self.psi_path + 'image/' + image_name + '.png'
        if type_ == 'vlm':
            res_name = image_name + '-mean.png'
            if res_name in self.vlm_image_name_list:
                sub_dir = 'image/image/'
            elif res_name in self.vlm_image_name_over3std_list:
                sub_dir = 'image/over3std/'
            else:
                sub_dir = 'image/trend/'
            return self.vlm_path + sub_dir + res_name
        # Fail here instead of returning None, which would only blow up later
        # inside add_picture with an unrelated-looking error.
        raise ValueError('图片种类输入错误: %r' % (type_,))

    def summary(self):
        '''Write the rule-based summary section (not implemented yet).'''
        pass

    def show_detail(self, name):
        '''
        Append the detail section of model ``name``: a level-1 heading, the
        AUC table, the PSI pictures and the VLM pictures.
        '''
        self.doc.add_heading(name, level=1)
        self._show_auc(name)
        self._show_psi(name)
        self._show_vlm(name)

    def _add_pictures(self, image_paths):
        '''Insert every path of ``image_paths`` into one new centered paragraph
        (the paragraph is added even when ``image_paths`` is empty).'''
        paragraph = self.doc.add_paragraph()
        paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
        run = paragraph.add_run('')
        for path in image_paths:
            run.add_picture(path, width=Cm(self.IMAGE_WIDTH_CM))

    def _show_auc(self, name):
        '''Add the AUC table of model ``name`` (at most ``MAX_AUC_ROWS`` rows).'''
        self.doc.add_paragraph('模型区分度(AUC)', style='List Bullet')

        # The lexicographically largest 'NUM' column is used as the most recent
        # period -- assumes column names sort chronologically (TODO confirm).
        num_cols = [x for x in self.auc_info_df.columns.tolist() if 'NUM' in x]
        recent_col = max(num_cols)

        model_df = self.auc_info_df.loc[self.auc_info_df['model_name'] == name]
        model_df = model_df.sort_values(recent_col, ascending=False)

        auc_cols = [x for x in model_df.columns if 'AUC' in x]
        auc_df = model_df[['group_name'] + auc_cols].head(self.MAX_AUC_ROWS)
        # Drop the last three characters (the 'AUC' suffix) from the headers.
        auc_df.columns = [x[: -3] if 'AUC' in x else x for x in auc_df.columns]

        n_row, n_col = auc_df.shape
        table = self.doc.add_table(rows=n_row + 1, cols=n_col, style='Table Grid')

        for cell, col in zip(table.rows[0].cells, auc_df.columns.tolist()):
            cell.text = str(col)
        for r in range(n_row):
            cells = table.rows[r + 1].cells
            for c in range(n_col):
                cells[c].text = str(auc_df.iat[r, c])

    def _show_psi(self, name):
        '''Add the PSI pictures of model ``name``: the whole-sample picture plus
        up to ``MAX_IMAGES`` abnormal groups.'''
        self.doc.add_paragraph('模型稳定性(PSI)', style='List Bullet')

        psi_df = self.psi_info_df
        abnormal_df = psi_df.loc[(psi_df['field_query'] == name)
                                 & psi_df['is_abnormal']].head(self.MAX_IMAGES)

        image_paths = []
        # Always show the whole-sample picture; add it explicitly unless it is
        # already among the (truncated) abnormal groups.
        if not any('全样本' in x for x in abnormal_df['group_name'].tolist()):
            image_paths.append(self.get_image(name + '-全样本'))
        # Read the image key by column name rather than by position.
        image_paths.extend(self.get_image(x) for x in abnormal_df['name'].tolist())
        self._add_pictures(image_paths)

    def _show_vlm(self, name):
        '''Add the VLM (mean-value) pictures of model ``name``: the overall
        picture, then up to ``MAX_IMAGES`` over-3-std and trend pictures.'''
        self.doc.add_paragraph('模型均值变化(VLM)', style='List Bullet')

        model_df = self.vlm_info_df.loc[self.vlm_info_df['model_name'] == name]
        over3std_df = model_df.loc[model_df['over_3std']].head(self.MAX_IMAGES)
        trend_df = model_df.loc[model_df['h']].sort_values('p').head(self.MAX_IMAGES)

        # These are VLM pictures, so they must be resolved with type_='vlm'.
        self.doc.add_paragraph('整体', style='List Bullet 2')
        self._add_pictures([self.get_image(name + '-全样本', type_='vlm')])

        # Image keys are read positionally from the 4th column of vlm_info.csv
        # (NOTE(review): confirm that column's name and use it instead).
        self.doc.add_paragraph('波动', style='List Bullet 2')
        self._add_pictures([self.get_image(x, type_='vlm')
                            for x in over3std_df.iloc[:, 3].tolist()])

        self.doc.add_paragraph('趋势', style='List Bullet 2')
        self._add_pictures([self.get_image(x, type_='vlm')
                            for x in trend_df.iloc[:, 3].tolist()])

    def format(self):
        '''Apply a uniform format to the document (not implemented yet).'''
        pass

    @staticmethod
    def _report_filename(date):
        '''Return the report file name for the month of ``date``,
        e.g. ``MM_report_20230101.docx`` (month zero-padded).'''
        return 'MM_report_%d%02d01.docx' % (date.year, date.month)

    @staticmethod
    def _pdf_path(doc_path):
        '''Return ``doc_path`` with its extension replaced by ``.pdf``.'''
        return os.path.splitext(doc_path)[0] + '.pdf'

    @staticmethod
    def doc2pdf(doc_path):
        '''
        Export the Word document ``doc_path`` to a PDF next to it.

        Requires Windows with Microsoft Word and ``win32com``. Word is quit
        even when opening or exporting the document raises.
        '''
        from win32com.client import constants, gencache
        pdf_path = AutoReportor._pdf_path(doc_path)
        word = gencache.EnsureDispatch('Word.Application')
        try:
            doc = word.Documents.Open(doc_path, ReadOnly=1)
            doc.ExportAsFixedFormat(pdf_path,
                                    constants.wdExportFormatPDF,
                                    Item=constants.wdExportDocumentWithMarkup,
                                    CreateBookmarks=constants.wdExportCreateHeadingBookmarks)
        finally:
            word.Quit(constants.wdDoNotSaveChanges)

    def run(self):
        '''Write every model's details, format, save the .docx into the current
        directory and optionally export it to PDF.'''
        # TODO: write the rule-based summary here once ``summary`` exists.

        for name in self.field_name_list:
            print('正在处理 ' + name)
            self.show_detail(name)

        self.format()

        file_name = self._report_filename(datetime.today())
        self.doc.save(file_name)

        if self.is_pdf:
            # Hand Word an absolute path to the file just saved.
            self.doc2pdf(os.path.join(os.getcwd(), file_name))

        print('报告生成完毕, 保存成功.')

