package com.aote;

import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.json.JSONTokener;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Pattern;

/**
 * SQL keyword filter guarding request parameters against SQL injection.
 *
 * <p>Enhanced version: besides the keyword deny-list it screens for blind-injection payloads,
 * simple encoding bypasses and suspicious SQL structure, and offers a per-IP request rate limit.
 * Every {@code check*} method returns normally when the input is accepted and throws
 * {@link RuntimeException} when it is rejected.
 *
 * <p>This is a heuristic deny-list; it complements, but does not replace, parameterized queries.
 */
public class ParamFilter {
    static final Logger log = Logger.getLogger(ParamFilter.class);

    /** Message of the exception thrown for any rejected input. */
    private static final String ILLEGAL_PARAM_MESSAGE = "参数包含非法字符";

    /** Message of the exception thrown when a client exceeds the rate limit. */
    private static final String RATE_LIMITED_MESSAGE = "请求频率过高，已被暂时阻止";

    /**
     * Forbidden SQL keywords, compared against whole tokens of the lower-cased input.
     *
     * <p>The tokenizer splits on whitespace, ',', ';', '(' and ')', so entries containing one of
     * those characters ("asc(", "mid(", "char(", "if(", "cast(", "convert(", "user()",
     * "version()", "database()", ";") can never equal a token. Except for "asc(", the
     * corresponding function calls are caught by {@link #BLIND_INJECTION_PATTERNS}; a ';' is only
     * rejected by {@link #checkSqlSyntax} when the input also contains a DML keyword.
     */
    private static final Set<String> SQL_KEYWORDS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
            "select", "insert", "update", "delete", "from", "drop", "count", "table", "truncate",
            "declare", "asc(", "mid(", "char(", "where", "master", "netlocalgroup", "administrators",
            "xp_cmdshell",
            "exec", "execute", "xp_", "sp_", "0x", ";", "or", "\"t_user\"", "--", "#", "union", "/", "//",
            // Keywords related to blind injection
            "waitfor", "delay", "sleep", "benchmark", "ascii", "substring", "substr", "length", "if(",
            "case", "when", "extractvalue", "updatexml", "cast(", "convert(", "user()", "version()",
            "database()", "@@version", "@@user", "information_schema", "load_file", "into", "outfile",
            "dumpfile")));

    /** Blind-injection and related payload patterns; any match rejects the input. */
    private static final Pattern[] BLIND_INJECTION_PATTERNS = {
            // Time-delay attacks
            Pattern.compile("(?i)\\b(waitfor\\s+delay|sleep\\s*\\(|benchmark\\s*\\(|pg_sleep\\s*\\()"),
            Pattern.compile("(?i)\\b(dbms_pipe\\.receive_message|dbms_lock\\.sleep)"),

            // Boolean-based blind injection
            Pattern.compile("(?i)\\b(and|or)\\s+\\d+\\s*[=<>]\\s*\\d+"),
            Pattern.compile("(?i)\\b(and|or)\\s+['\"]?\\w+['\"]?\\s*[=<>]\\s*['\"]?\\w+['\"]?"),

            // Data-extraction functions
            Pattern.compile("(?i)\\b(ascii|char|substring|substr|mid|left|right|len|length)\\s*\\("),
            Pattern.compile("(?i)\\b(hex|unhex|ord|conv|bin)\\s*\\("),

            // Conditionals and error-based injection
            Pattern.compile("(?i)\\b(if\\s*\\(|case\\s+when|iif\\s*\\()"),
            Pattern.compile("(?i)\\b(cast\\s*\\(|convert\\s*\\(|extractvalue\\s*\\(|updatexml\\s*\\()"),

            // System functions and information gathering
            Pattern.compile("(?i)\\b(user\\s*\\(|version\\s*\\(|database\\s*\\(|schema\\s*\\(|@@version|@@user)"),
            Pattern.compile("(?i)\\b(information_schema|sys\\.|msdb\\.|master\\.)"),

            // Encoding bypasses
            Pattern.compile("(?i)\\b0x[0-9a-f]+"),
            Pattern.compile("(?i)\\bchar\\s*\\(\\s*\\d+"),

            // UNION injection variants
            Pattern.compile("(?i)\\bunion\\s+(all\\s+)?select"),
            Pattern.compile("(?i)\\bunion\\s*\\(\\s*select"),

            // Comment markers. DOTALL / MULTILINE so that a comment spanning, or followed by, a
            // line break (e.g. "1/*\n*/or/*\n*/1=1", "1--\ndrop table t") is still detected.
            Pattern.compile("(?s)/\\*.*?\\*/"),
            Pattern.compile("(?im)#.*$"),
            Pattern.compile("(?im)--.*$"),

            // Special-character combinations (possible blind-injection payloads)
            Pattern.compile("['\"]\\s*[+]\\s*['\"]"),
            Pattern.compile("\\|\\|"),

            // Database-specific techniques
            Pattern.compile("(?i)\\b(load_file|into\\s+outfile|into\\s+dumpfile)"),
            Pattern.compile("(?i)\\b(exec\\s*\\(|execute\\s*\\(|sp_executesql)")
            // Deliberately disabled (presumably too prone to false positives - confirm before
            // re-enabling): "(?i)\\b(and|or)\\s+(true|false|null)", "(?i)\\b(and|or)\\s+not\\s+", "&&".
    };

    /** Block comment, possibly spanning lines. */
    private static final Pattern BLOCK_COMMENT_PATTERN = Pattern.compile("(?s)/\\*.*?\\*/");
    /** Line comment ("--" or "#") up to, but not including, the end of its line. */
    private static final Pattern LINE_COMMENT_PATTERN = Pattern.compile("(?m)(--|#).*$");
    /** Separators used to split input into keyword tokens. */
    private static final Pattern TOKEN_SEPARATOR_PATTERN = Pattern.compile("[\\s,;()]+");
    private static final Pattern HEX_LITERAL_PATTERN = Pattern.compile("\\b0x[0-9a-fA-F]+");
    private static final Pattern UNICODE_ESCAPE_PATTERN = Pattern.compile("\\\\u[0-9a-fA-F]{4}");
    /** Two DML keywords anywhere in the input (DOTALL so line breaks cannot hide the second). */
    private static final Pattern STACKED_QUERY_PATTERN =
            Pattern.compile("(?s)\\b(select|insert|update|delete).*\\b(select|insert|update|delete)");

    // Request rate limiting - slows down brute-force blind injection.
    private static final Map<String, RequestTracker> requestTrackers = new ConcurrentHashMap<>();
    private static final int MAX_REQUESTS_PER_MINUTE = 60;
    private static final long REQUEST_WINDOW_MS = 60000; // 1 minute
    /** How long a client exceeding the limit stays blocked (5 minutes). */
    private static final long BLOCK_DURATION_MS = REQUEST_WINDOW_MS * 5;
    /** Trackers idle for longer than this (6 minutes) are discarded by cleanup. */
    private static final long TRACKER_EXPIRY_MS = REQUEST_WINDOW_MS * 6;
    /** Cleanup is only attempted once more than this many clients are tracked. */
    private static final int MAX_TRACKERS_BEFORE_CLEANUP = 10000;
    /** Time of the last cleanup pass; limits the full scan to at most once per window. */
    private static final AtomicLong lastCleanupTime = new AtomicLong(0L);

    /**
     * Fixed-window request counter for one client. Several request threads can share a tracker,
     * so all state is guarded by the tracker's own monitor.
     */
    static class RequestTracker {
        /** Start of the current counting window. */
        private long windowStart;
        /** Time of the most recent counted request; used to expire idle trackers. */
        private long lastRequestTime;
        /** Requests counted in the current window. */
        private int requestCount;
        private boolean blocked;
        private long blockUntil;

        RequestTracker() {
            long now = System.currentTimeMillis();
            this.windowStart = now;
            this.lastRequestTime = now;
            // 0, not 1: the request that creates the tracker is counted by checkAndUpdate().
            this.requestCount = 0;
            this.blocked = false;
        }

        /**
         * Counts one request and decides whether it is allowed.
         *
         * @return {@code true} if allowed; {@code false} if the client is, or has just become, blocked
         */
        synchronized boolean checkAndUpdate() {
            long now = System.currentTimeMillis();

            if (blocked) {
                if (now < blockUntil) {
                    return false; // still inside the block period
                }
                // Block expired: start a fresh window.
                blocked = false;
                requestCount = 0;
                windowStart = now;
            }

            // The window is measured from its first request, not from the latest one; otherwise
            // a steady trickle of requests would never reset the count.
            if (now - windowStart > REQUEST_WINDOW_MS) {
                windowStart = now;
                requestCount = 0;
            }

            requestCount++;
            lastRequestTime = now;

            if (requestCount > MAX_REQUESTS_PER_MINUTE) {
                blocked = true;
                blockUntil = now + BLOCK_DURATION_MS;
                log.warn("IP请求频率过高被阻止. 请求次数: " + requestCount);
                return false;
            }
            return true;
        }

        /**
         * Whether this tracker has been idle long enough to be discarded. The expiry (6 min)
         * exceeds the block duration (5 min), so a blocked client is never forgotten early.
         */
        synchronized boolean isExpired(long now) {
            return now - lastRequestTime > TRACKER_EXPIRY_MS;
        }
    }

    /**
     * Enhanced SQL-injection check with per-IP rate limiting.
     *
     * @param map      values to check (keys are not checked)
     * @param clientIp client IP address; blank or "unknown" disables rate limiting
     * @throws RuntimeException if the client is rate-limited or a value is rejected
     */
    public static void checkSqlMapEnhanced(Map<String, Object> map, String clientIp){
        if (!checkRateLimit(clientIp)) {
            throw new RuntimeException(RATE_LIMITED_MESSAGE);
        }
        for (Object value : map.values()) {
            checkSqlMapValue(value, clientIp);
        }
    }

    /**
     * SQL-injection check without rate limiting (kept for backward compatibility).
     *
     * @param map values to check (keys are not checked)
     * @throws RuntimeException if a value is rejected
     */
    public static void checkSqlMap(Map<String, Object> map){
        for (Object value : map.values()) {
            checkSqlMapValue(value, "unknown");
        }
    }

    /**
     * Checks one map value. Arrays, collections, maps and JSON values are inspected element by
     * element, because their {@code String.valueOf} form (e.g. "[Ljava.lang.String;@1b6d3586" for
     * a {@code String[]}, or "[select, x]" for a list) would hide the actual values from the checks.
     */
    private static void checkSqlMapValue(Object value, String clientIp) {
        if (value instanceof JSONObject || value instanceof JSONArray) {
            checkSqlJsonVEnhanced(value, clientIp);
        } else if (value instanceof Object[]) {
            for (Object element : (Object[]) value) {
                checkSqlMapValue(element, clientIp);
            }
        } else if (value instanceof Iterable) {
            for (Object element : (Iterable<?>) value) {
                checkSqlMapValue(element, clientIp);
            }
        } else if (value instanceof Map) {
            for (Object element : ((Map<?, ?>) value).values()) {
                checkSqlMapValue(element, clientIp);
            }
        } else {
            checkSqlStrEnhanced(String.valueOf(value), clientIp);
        }
    }

    /**
     * Enhanced SQL-injection check of a JSON document with per-IP rate limiting.
     *
     * @param jsonStr  JSON object text; only values are checked
     * @param clientIp client IP address; blank or "unknown" disables rate limiting
     * @throws RuntimeException if the client is rate-limited or a value is rejected;
     *                          a JSON parse error if {@code jsonStr} is not a JSON object
     */
    public static void checkSqlJsonStrEnhanced(String jsonStr, String clientIp) {
        if (log.isDebugEnabled()) {
            log.debug("待检数据：" + forLog(jsonStr));
        }

        if (!checkRateLimit(clientIp)) {
            throw new RuntimeException(RATE_LIMITED_MESSAGE);
        }

        JSONObject jsonObject = new JSONObject(jsonStr);
        checkSqlJsonEnhanced(jsonObject, clientIp);
    }

    /**
     * SQL-injection check of a JSON document without rate limiting (kept for backward compatibility).
     *
     * @param jsonStr JSON object text; only values are checked
     */
    public static void checkSqlJsonStr(String jsonStr) {
        if (log.isDebugEnabled()) {
            log.debug("待检数据：" + forLog(jsonStr));
        }
        JSONObject jsonObject = new JSONObject(jsonStr);
        checkSqlJsonEnhanced(jsonObject, "unknown");
    }

    /**
     * Checks every value of a JSON object, recursing into nested objects and arrays.
     *
     * @param jsonObject object to check; keys are not checked
     * @param clientIp   client IP, used for logging
     */
    public static void checkSqlJsonEnhanced(JSONObject jsonObject, String clientIp) {
        for (String k : jsonObject.keySet()) {
            checkSqlJsonVEnhanced(jsonObject.get(k), clientIp);
        }
    }

    /**
     * SQL-injection check of a JSON object (kept for backward compatibility).
     *
     * @param jsonObject object to check; keys are not checked
     */
    public static void checkSqlJson(JSONObject jsonObject) {
        checkSqlJsonEnhanced(jsonObject, "unknown");
    }

    /**
     * Checks a single JSON value: arrays and objects recursively, and strings either as embedded
     * JSON (when they parse cleanly as a JSON object/array) or as plain text. Other value types
     * (e.g. numbers, booleans, JSONObject.NULL) are not checked.
     */
    public static void checkSqlJsonVEnhanced(Object v, String clientIp){
        if (v instanceof JSONArray) {
            // Re-parsed from its text form, as in the original implementation; presumably this
            // normalises raw Java values stored in the array into JSON types.
            JSONArray ja = new JSONArray(String.valueOf(v));
            for (int i = 0; i < ja.length(); i++) {
                checkSqlJsonVEnhanced(ja.get(i), clientIp);
            }
        } else if (v instanceof JSONObject) {
            checkSqlJsonEnhanced((JSONObject) v, clientIp);
        } else if (v instanceof String) {
            String vStr = (String) v;
            Object embedded = parseEmbeddedJson(vStr);
            if (embedded != null) {
                checkSqlJsonVEnhanced(embedded, clientIp);
            } else {
                checkSqlStrEnhanced(vStr, clientIp);
            }
        }
    }

    /** Single-value JSON check (kept for backward compatibility). */
    public static void checkSqlJsonV(Object v){
        checkSqlJsonVEnhanced(v, "unknown");
    }

    /**
     * Parses a string value that looks like embedded JSON (starts with '{' or '[').
     *
     * @return the parsed JSONObject/JSONArray, or {@code null} when the string does not start
     *         with '{' or '[', is not valid JSON, or has content after the JSON value. In the last
     *         two cases the caller checks the whole string as text, so a benign value such as
     *         "[note] text" is not rejected with a parse error, and a payload appended after valid
     *         JSON (e.g. "{} union select ...") - which the parser would otherwise silently
     *         ignore - is still inspected.
     */
    private static Object parseEmbeddedJson(String s) {
        if (!s.startsWith("{") && !s.startsWith("[")) {
            return null;
        }
        try {
            JSONTokener tokener = new JSONTokener(s);
            Object value = tokener.nextValue();
            // nextClean() returns 0 only when nothing but whitespace remains.
            return tokener.nextClean() == 0 ? value : null;
        } catch (JSONException e) {
            return null;
        }
    }

    /**
     * Enhanced SQL-injection check of a single string.
     *
     * @param str      value to check; blank values are accepted
     * @param clientIp client IP, used for logging
     * @throws RuntimeException if the value is rejected
     */
    public static void checkSqlStrEnhanced(String str, String clientIp){
        if (StringUtils.isBlank(str)) {
            return;
        }

        long startTime = System.currentTimeMillis();

        try {
            // 1. Blind-injection patterns
            checkBlindInjectionPatterns(str, clientIp);

            // 2. Keyword deny-list
            checkSqlKeywords(str, clientIp);

            // 3. Encoding bypass attempts
            checkEncodingBypass(str, clientIp);

            // 4. SQL syntax structure
            checkSqlSyntax(str, clientIp);

        } finally {
            long processingTime = System.currentTimeMillis() - startTime;

            // Unusually slow checks are logged (possibly an input crafted to slow the regexes).
            if (processingTime > 100) {
                log.warn("检测处理时间异常: " + processingTime + "ms, IP: " + forLog(clientIp) + ", 输入: " + forLog(str));
            }
        }
    }

    /**
     * SQL-injection check of a single string (kept for backward compatibility).
     *
     * @param str value to check
     */
    public static void checkSqlStr(String str){
        checkSqlStrEnhanced(str, "unknown");
    }

    /** Rejects input matching any of {@link #BLIND_INJECTION_PATTERNS}. */
    private static void checkBlindInjectionPatterns(String str, String clientIp) {
        String lowerStr = str.toLowerCase(Locale.ROOT);

        for (Pattern pattern : BLIND_INJECTION_PATTERNS) {
            if (pattern.matcher(str).find() || pattern.matcher(lowerStr).find()) {
                log.error("检测到盲注攻击模式, IP: " + forLog(clientIp) + ", 模式: " + pattern.pattern() + ", 输入: " + forLog(str));
                throw new RuntimeException(ILLEGAL_PARAM_MESSAGE);
            }
        }
    }

    /**
     * Rejects input containing a token from {@link #SQL_KEYWORDS}. The input is lower-cased
     * (locale-independently, so e.g. "INSERT" cannot become "ınsert") and tokenized twice:
     * with comments removed outright (catches "sel/**&#47;ect") and with comments replaced by a
     * space (catches "union/**&#47;select").
     */
    private static void checkSqlKeywords(String str, String clientIp) {
        String lowerStr = str.toLowerCase(Locale.ROOT).trim();
        checkKeywordTokens(stripComments(lowerStr, ""), str, clientIp);
        checkKeywordTokens(stripComments(lowerStr, " "), str, clientIp);
    }

    /**
     * Removes comments before tokenizing. Block comments are replaced by
     * {@code blockReplacement}. Line comments are removed only up to the end of their line, so
     * SQL on following lines (e.g. "1--\ndrop table t") stays visible.
     */
    private static String stripComments(String lowerStr, String blockReplacement) {
        String withoutBlockComments = BLOCK_COMMENT_PATTERN.matcher(lowerStr).replaceAll(blockReplacement);
        return LINE_COMMENT_PATTERN.matcher(withoutBlockComments).replaceAll(" ");
    }

    /** Throws if any token of {@code normalized} is a forbidden keyword. */
    private static void checkKeywordTokens(String normalized, String original, String clientIp) {
        for (String token : TOKEN_SEPARATOR_PATTERN.split(normalized)) {
            if (StringUtils.isNotBlank(token) && SQL_KEYWORDS.contains(token)) {
                log.error("参数包含SQL关键字: " + token + ", IP: " + forLog(clientIp) + ", 完整输入: " + forLog(original));
                throw new RuntimeException(ILLEGAL_PARAM_MESSAGE);
            }
        }
    }

    /**
     * Rejects hex literals and "\\uXXXX" escapes, then re-runs the pattern and keyword checks on
     * a copy with common URL escapes decoded (case-insensitively).
     */
    private static void checkEncodingBypass(String str, String clientIp) {
        // Hexadecimal literal
        if (HEX_LITERAL_PATTERN.matcher(str).find()) {
            log.error("检测到十六进制编码尝试, IP: " + forLog(clientIp) + ", 输入: " + forLog(str));
            throw new RuntimeException(ILLEGAL_PARAM_MESSAGE);
        }

        // Unicode escape
        if (UNICODE_ESCAPE_PATTERN.matcher(str).find()) {
            log.error("检测到Unicode编码尝试, IP: " + forLog(clientIp) + ", 输入: " + forLog(str));
            throw new RuntimeException(ILLEGAL_PARAM_MESSAGE);
        }

        // URL-encoded SQL syntax
        String decoded = str.replaceAll("(?i)%20", " ")
                .replaceAll("(?i)%27", "'")
                .replaceAll("(?i)%22", "\"")
                .replaceAll("(?i)%3B", ";")
                .replaceAll("(?i)%2D%2D", "--")
                .replaceAll("(?i)%2F%2A", "/*")
                .replaceAll("(?i)%2A%2F", "*/");

        if (!decoded.equals(str)) {
            checkBlindInjectionPatterns(decoded, clientIp);
            checkSqlKeywords(decoded, clientIp);
        }
    }

    /** Rejects multi-statement, comment-terminated and stacked-query shapes. */
    private static void checkSqlSyntax(String str, String clientIp) {
        String lowerStr = str.toLowerCase(Locale.ROOT);

        // Multiple statements
        if (lowerStr.contains(";") && (lowerStr.contains("select") || lowerStr.contains("insert")
                || lowerStr.contains("update") || lowerStr.contains("delete"))) {
            log.error("检测到多语句注入尝试, IP: " + forLog(clientIp) + ", 输入: " + forLog(str));
            throw new RuntimeException(ILLEGAL_PARAM_MESSAGE);
        }

        // Comment marker combined with a query
        if ((lowerStr.contains("--") || lowerStr.contains("/*"))
                && (lowerStr.contains("select") || lowerStr.contains("union"))) {
            log.error("检测到注释符SQL注入尝试, IP: " + forLog(clientIp) + ", 输入: " + forLog(str));
            throw new RuntimeException(ILLEGAL_PARAM_MESSAGE);
        }

        // Stacked queries
        if (STACKED_QUERY_PATTERN.matcher(lowerStr).find()) {
            log.error("检测到堆叠查询尝试, IP: " + forLog(clientIp) + ", 输入: " + forLog(str));
            throw new RuntimeException(ILLEGAL_PARAM_MESSAGE);
        }
    }

    /**
     * Rate-limit check for one request.
     *
     * @return {@code true} if allowed; blank or "unknown" IPs are never limited
     */
    private static boolean checkRateLimit(String clientIp) {
        if (StringUtils.isBlank(clientIp) || "unknown".equals(clientIp)) {
            return true;
        }

        boolean allowed = requestTrackers.computeIfAbsent(clientIp, k -> new RequestTracker()).checkAndUpdate();

        // Bound memory use, but scan at most once per window so that a large map does not make
        // every request pay for a full O(n) pass.
        if (requestTrackers.size() > MAX_TRACKERS_BEFORE_CLEANUP) {
            long now = System.currentTimeMillis();
            long last = lastCleanupTime.get();
            if (now - last >= REQUEST_WINDOW_MS && lastCleanupTime.compareAndSet(last, now)) {
                cleanupExpiredTrackers();
            }
        }

        return allowed;
    }

    /** Removes trackers idle for longer than {@link #TRACKER_EXPIRY_MS}. */
    private static void cleanupExpiredTrackers() {
        long now = System.currentTimeMillis();
        requestTrackers.values().removeIf(tracker -> tracker.isExpired(now));
    }

    /** Escapes CR/LF so untrusted input cannot forge additional log lines. */
    private static String forLog(String value) {
        return value == null ? null : value.replace("\r", "\\r").replace("\n", "\\n");
    }

}
