
package cn.quantgroup.big.stms.common.filter;

import org.apache.commons.lang3.StringUtils;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.jsoup.safety.Whitelist;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;


/**
 * Request wrapper that sanitizes request data against XSS using a jsoup whitelist.
 *
 * <p>The request body ({@link #getInputStream()} / {@link #getReader()}), single parameter values,
 * parameter value arrays, the parameter map and single header values are passed through
 * {@link Jsoup#clean}. {@link #getParameter(String)} returns parameters named {@code "content"} or
 * whose name ends with {@code "withHtml"} without sanitizing them.
 *
 * <p>The sanitized body is cached in an unsynchronized field; like the underlying request, an
 * instance is meant to be used by the thread handling the request.
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Parameter name whose value {@link #getParameter(String)} returns without sanitizing. */
    private static final String RAW_PARAM_NAME = "content";

    /** Parameter-name suffix whose values {@link #getParameter(String)} returns without sanitizing. */
    private static final String RAW_PARAM_SUFFIX = "withHtml";

    /** Size of the character buffer used when reading the request body. */
    private static final int READ_BUFFER_SIZE = 4096;

    /**
     * Tags and attributes allowed through sanitization: jsoup's "relaxed" set plus a {@code style}
     * attribute on every tag (added in the static initializer below).
     */
    private static final Whitelist WHITELIST = Whitelist.relaxed();

    /** Output settings for sanitized HTML: pretty-printing is disabled so content is not reformatted. */
    private static final Document.OutputSettings OUTPUT_SETTINGS =
            new Document.OutputSettings().prettyPrint(false);

    static {
        // Rich-text editors express some formatting through inline styles
        // (e.g. red text as style="color:red;"), so "style" is allowed on all tags.
        WHITELIST.addAttributes(":all", "style");
    }

    /** The request this wrapper was created around. */
    private final HttpServletRequest originalRequest;

    /** Sanitized request body; read and cached on the first access to the body. */
    private byte[] cleanedBody;

    /**
     * Wraps {@code request} so that data read through this wrapper is sanitized.
     *
     * @param request the request to wrap
     */
    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
        this.originalRequest = request;
    }

    /**
     * Returns a stream over the sanitized request body.
     *
     * <p>On first access the original body is read fully, decoded with the request's character
     * encoding (UTF-8 when none or an unusable one is declared), sanitized as a whole (so tags that
     * span line breaks are still recognized and line breaks are preserved), re-encoded with the same
     * charset and cached. Every call returns a fresh stream over the cached bytes.
     *
     * <p>NOTE(review): jsoup entity-escapes text (e.g. {@code &} becomes {@code &amp;}), so non-HTML
     * bodies such as JSON or form data may be altered; multipart bodies are not special-cased here.
     *
     * @return a stream positioned at the start of the sanitized body
     * @throws IOException if reading the original body fails
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new WrappedServletInputStream(new ByteArrayInputStream(getCleanedBody()));
    }

    /**
     * Returns a reader over the sanitized request body. This is overridden so that callers using
     * the reader cannot bypass the sanitization applied by {@link #getInputStream()}.
     *
     * @return a reader over the sanitized body, decoded with the same charset used to sanitize it
     * @throws IOException if reading the original body fails
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset()));
    }

    /**
     * Returns a parameter value with both the parameter name and the value XSS-sanitized. The value
     * is looked up under the sanitized name.
     *
     * <p>Parameters named {@code "content"} or ending with {@code "withHtml"} (and a {@code null}
     * name) are delegated unchanged to the wrapped request. To get the raw value of any other
     * parameter, use {@link #getOriginalRequestRequest()}.
     *
     * <p>getParameterNames may also need to be overridden; it currently returns raw names.
     *
     * @param name the parameter name
     * @return the sanitized value; a blank or {@code null} value is returned as-is
     */
    @Override
    public String getParameter(String name) {
        if (name == null || isRawParameter(name)) {
            return super.getParameter(name);
        }
        String value = super.getParameter(clean(name));
        return StringUtils.isNotBlank(value) ? clean(value) : value;
    }

    /**
     * Returns a new map from parameter name to the sanitized, trimmed value. Multiple values are
     * joined with {@code ","}, and a {@code null} value becomes {@code ""}.
     *
     * <p>NOTE(review): the values are {@code String}, not {@code String[]} as the Servlet contract
     * specifies. This is kept for compatibility with existing callers.
     *
     * @return a new, mutable map of sanitized values
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    @Override
    public Map getParameterMap() {
        Map<?, ?> params = super.getParameterMap();
        Map<String, String> cleaned = new HashMap<>();
        for (Map.Entry<?, ?> entry : params.entrySet()) {
            String name = (String) entry.getKey();
            cleaned.put(name, clean(joinValues(entry.getValue())).trim());
        }
        return cleaned;
    }

    /**
     * Returns the sanitized values of a parameter. The result is a new array, so the wrapped
     * request's own array is never modified. This keeps repeated calls from re-escaping values.
     *
     * @param name the parameter name
     * @return sanitized copies of the values ({@code null} elements kept), or {@code null} if the
     *         parameter does not exist
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = super.getParameterValues(name);
        if (values == null) {
            return null;
        }
        String[] cleaned = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            cleaned[i] = values[i] == null ? null : clean(values[i]);
        }
        return cleaned;
    }

    /**
     * Returns a header value with both the header name and the value XSS-sanitized. The value is
     * looked up under the sanitized name, and a {@code null} name is delegated unchanged.
     *
     * <p>{@code getHeaders(name)} and {@code getHeaderNames()} are not overridden and return raw
     * values.
     *
     * @param name the header name
     * @return the sanitized value; a blank or {@code null} value is returned as-is
     */
    @Override
    public String getHeader(String name) {
        if (name == null) {
            return super.getHeader(name);
        }
        String value = super.getHeader(clean(name));
        return StringUtils.isNotBlank(value) ? clean(value) : value;
    }

    /**
     * Returns the request this wrapper was created around.
     *
     * @return the wrapped, unsanitized request
     */
    public HttpServletRequest getOriginalRequestRequest() {
        return originalRequest;
    }

    /**
     * Returns the unsanitized request behind {@code req}.
     *
     * @param req a request, possibly an {@code XssHttpServletRequestWrapper}
     * @return the wrapped request if {@code req} is this wrapper type, otherwise {@code req} itself
     */
    public static HttpServletRequest getOriginalRequestRequest(HttpServletRequest req) {
        if (req instanceof XssHttpServletRequestWrapper) {
            return ((XssHttpServletRequestWrapper) req).getOriginalRequestRequest();
        }
        return req;
    }

    /** Returns true for parameters that {@link #getParameter(String)} returns unsanitized. */
    private static boolean isRawParameter(String name) {
        return RAW_PARAM_NAME.equals(name) || name.endsWith(RAW_PARAM_SUFFIX);
    }

    /** Joins a parameter-map value into one string: arrays comma-joined, null as "". */
    private static String joinValues(Object valueObj) {
        if (valueObj == null) {
            return "";
        }
        if (valueObj instanceof String[]) {
            String[] values = (String[]) valueObj;
            StringBuilder joined = new StringBuilder();
            for (int i = 0; i < values.length; i++) {
                if (i > 0) {
                    joined.append(',');
                }
                joined.append(values[i]);
            }
            return joined.toString();
        }
        return valueObj.toString();
    }

    /** Reads, sanitizes and caches the original request body on first use. */
    private byte[] getCleanedBody() throws IOException {
        if (cleanedBody == null) {
            Charset charset = bodyCharset();
            String rawBody = readFully(originalRequest.getInputStream(), charset);
            cleanedBody = clean(rawBody).getBytes(charset);
        }
        return cleanedBody;
    }

    /** Returns the request's declared charset, or UTF-8 when none or an unusable one is declared. */
    private Charset bodyCharset() {
        String encoding = originalRequest.getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding.trim());
            } catch (IllegalArgumentException ignored) {
                // Illegal or unsupported charset name: fall back to UTF-8 rather than failing the request.
            }
        }
        return StandardCharsets.UTF_8;
    }

    /** Reads {@code in} to the end with {@code charset}, then closes it. */
    private static String readFully(InputStream in, Charset charset) throws IOException {
        StringBuilder body = new StringBuilder();
        try (Reader reader = new InputStreamReader(in, charset)) {
            char[] buffer = new char[READ_BUFFER_SIZE];
            int count;
            while ((count = reader.read(buffer)) != -1) {
                body.append(buffer, 0, count);
            }
        }
        return body.toString();
    }

    /** Sanitizes {@code content} with the shared whitelist and non-pretty-printing output settings. */
    private static String clean(String content) {
        return Jsoup.clean(content, "", WHITELIST, OUTPUT_SETTINGS);
    }

    /** Blocking {@link ServletInputStream} over an in-memory byte array. */
    private static final class WrappedServletInputStream extends ServletInputStream {
        private final ByteArrayInputStream stream;

        WrappedServletInputStream(ByteArrayInputStream stream) {
            this.stream = stream;
        }

        @Override
        public int read() {
            return stream.read();
        }

        @Override
        public int read(byte[] b, int off, int len) {
            return stream.read(b, off, len);
        }

        @Override
        public int available() {
            return stream.available();
        }

        /** Returns true once every byte has been consumed. */
        @Override
        public boolean isFinished() {
            return stream.available() == 0;
        }

        /** The data is already in memory, so reads never block. */
        @Override
        public boolean isReady() {
            return true;
        }

        /** Non-blocking reads are not supported; the listener is ignored. */
        @Override
        public void setReadListener(ReadListener readListener) {
            // Intentionally a no-op: this stream serves only blocking reads of an in-memory body.
        }
    }
}
