package com.aote.sql;

import com.aote.ThreadResource;
import com.aote.entity.EntityServer;
import org.apache.log4j.Logger;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.springframework.stereotype.Component;

import java.io.InputStream;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/**
 * SQL mapper.
 *
 * <p>Reads {@code module.xml} and {@code sql.xml} from the classpath when the class is
 * initialized and keeps, for every SQL alias, a small attribute map with the keys
 * {@code "alias"}, {@code "path"} and {@code "mobile"}.
 */
@Component
public class SqlMapper {

	static Logger LOGGER = Logger.getLogger(SqlMapper.class);

	/** Registered SQL entries: alias -> attributes ("alias", "path", "mobile"). */
	private static Map<String, Map<String, String>> map;

	static {
		// Load the mappings as soon as the class is initialized.
		loadMap();
	}

	/**
	 * Resolves the resource path of the SQL registered under the given alias.
	 *
	 * <p>A dialect-specific alias ({@code alias + "_oracle"} or {@code alias + "_mysql"},
	 * chosen from {@link EntityServer#findDialect()}) is preferred; when it is not
	 * registered the plain alias is used. For the alias that was selected, a path found
	 * by {@link #getSqlByUser(String)} takes precedence over the registered path.
	 *
	 * @param alias the SQL alias as declared in a {@code sql.xml}
	 * @return the resource path of the SQL, or {@code null} when neither the
	 *         dialect-specific alias nor the plain alias is registered
	 */
	public static String getSql(String alias)
	{
		String newAlias;
		switch (EntityServer.findDialect()){
			case EntityServer.DIALECT_ORACLE:
				newAlias = alias + "_oracle";
				break;
			case EntityServer.DIALECT_MYSQL:
				newAlias = alias + "_mysql";
				break;
			default:
				newAlias = alias;
		}
		if(!map.containsKey(newAlias)) {
			// No dialect-specific entry: fall back to the plain alias, or give up.
			if(map.containsKey(alias)){
				newAlias = alias;
			} else {
				return null;
			}
		}
		String result = getSqlByUser(newAlias);
		if(result != null) {
			return result;
		}
		return map.get(newAlias).get("path");
	}

	/**
	 * Returns the internal registry of SQL entries.
	 *
	 * @return the live alias -> attributes map (not a copy)
	 */
	public static Map<String, Map<String, String>> getMap() {
		return SqlMapper.map;
	}

	/**
	 * Looks for an override of the given SQL under the directory returned by
	 * {@code ThreadResource.ComponentDir.get()}.
	 *
	 * @param alias the SQL alias; used as the file name (without the {@code .sql} suffix)
	 * @return {@code "/<dir>/sqls/<alias>.sql"} when that resource exists, or {@code null}
	 *         when no directory is set or the resource is absent
	 */
	public static String getSqlByUser(String alias){
		String dir = ThreadResource.ComponentDir.get();
		if (dir == null) {
			return null;
		}
		// Check whether that directory contains a SQL file for this alias.
		String path = "/" + dir + "/sqls/" +alias + ".sql";
		if (SqlMapper.class.getResource(path) != null) {
			return path;
		}
		return null;
	}

	/**
	 * (Re)builds the registry: first the modules listed in {@code module.xml}, then the
	 * top-level {@code sql.xml}. Either file may be absent. Loading is best-effort: the
	 * first failure is logged and stops the load, leaving the entries registered so far.
	 */
	@SuppressWarnings("rawtypes")
	private static void loadMap() {
		map = new HashMap<>();

		try {
			// try-with-resources skips close() for a null resource, so a missing file is fine.
			try (InputStream module = SqlMapper.class.getClassLoader().getResourceAsStream("module.xml")) {
				if (module != null) {
					parseModule(module);
				}
			}
			try (InputStream input = SqlMapper.class.getClassLoader().getResourceAsStream("sql.xml")) {
				if (input != null) {
					SAXReader reader = new SAXReader();
					Document document = reader.read(input);
					registerSqlElements(document.getRootElement(), null, null);
				}
			}
		} catch (Exception ex) {
			// Keep class initialization alive, but report the failure through the logger.
			LOGGER.error("Failed to load SQL mappings from module.xml / sql.xml", ex);
		}
	}

	/**
	 * Reads {@code module.xml} and loads the {@code sql.xml} of every {@code <module>}
	 * element it declares.
	 *
	 * @param input stream of {@code module.xml}; the caller is responsible for closing it
	 * @throws Exception when the XML cannot be parsed or a module's SQL cannot be registered
	 */
	private static void parseModule(InputStream input) throws Exception {
		SAXReader reader = new SAXReader();
		Document document = reader.read(input);
		Element root = document.getRootElement();
		for (Iterator<Element> it = root.elementIterator("module"); it.hasNext();) {
			Element elm = it.next();
			String name = requiredAttribute(elm, "name");
			String path = elm.attributeValue("path");
			putMap(name, reader, path);
		}
	}

	/**
	 * Loads {@code <parent>/<name>/sql.xml} (or {@code <name>/sql.xml} when there is no
	 * parent) from the classpath and registers its {@code <sql>} elements. A module
	 * without such a file is silently skipped.
	 *
	 * @param name   module name
	 * @param reader reader used to parse the module's {@code sql.xml}
	 * @param parent optional parent path of the module, may be {@code null}
	 * @throws Exception when the XML cannot be parsed or an entry cannot be registered
	 */
	private static void putMap(String name, SAXReader reader, String parent) throws Exception {
		String str = (parent == null ? name + "/sql.xml" : parent + "/" + name + "/sql.xml");
		try (InputStream moduleSql = SqlMapper.class.getClassLoader().getResourceAsStream(str)) {
			if (moduleSql != null) {
				Document docSql = reader.read(moduleSql);
				registerSqlElements(docSql.getRootElement(), name, parent);
			}
		}
	}

	/** Registers every {@code <sql>} child of {@code root} for the given module. */
	private static void registerSqlElements(Element root, String moduleName, String moduleParent) throws Exception {
		for (Iterator<Element> item = root.elementIterator("sql"); item.hasNext();) {
			setSqlRes(item.next(), moduleName, moduleParent);
		}
	}

	/**
	 * Registers a single {@code <sql>} element.
	 *
	 * @param element      the {@code <sql>} element; must carry {@code alias} and {@code path},
	 *                     {@code mobile} is optional
	 * @param moduleName   owning module, or {@code null} for the top-level {@code sql.xml}
	 * @param moduleParent optional parent path of the module, may be {@code null}
	 * @throws Exception when the alias is already registered
	 * @throws IllegalStateException when {@code alias} or {@code path} is missing
	 */
	private static void setSqlRes(Element element, String moduleName, String moduleParent) throws Exception {
		String aliasSql = requiredAttribute(element, "alias");
		String pathSql = requiredAttribute(element, "path");
		String mobile = element.attributeValue("mobile");

		// Aliases must be unique across all loaded files.
		if (map.containsKey(aliasSql)){
			throw new Exception("Sql别名"+ aliasSql + "已存在");
		}

		String path;

		if(moduleName == null){
			path = "sqls/" + pathSql;
		} else {
			path = (moduleParent == null ? moduleName + "/sqls/" : moduleParent + "/" + moduleName + "/sqls/") + pathSql;
		}
		// Store the attributes of this SQL entry.
		Map<String, String> sql = new HashMap<>();
		sql.put("alias", aliasSql);
		sql.put("path", path);
		sql.put("mobile", mobile);
		map.put(aliasSql, sql);
	}

	/**
	 * Returns the value of a mandatory attribute, failing with a descriptive message
	 * (instead of a bare NullPointerException) when the attribute is absent.
	 */
	private static String requiredAttribute(Element element, String attributeName) {
		String value = element.attributeValue(attributeName);
		if (value == null) {
			throw new IllegalStateException(
					"<" + element.getName() + "> is missing required attribute '" + attributeName + "'");
		}
		return value;
	}
}
