package com.aote.sql;

import com.aote.ThreadResource;
import com.aote.entity.EntityServer;
import org.apache.log4j.Logger;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.springframework.stereotype.Component;

import java.io.InputStream;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/**
 * SQL mapper: a registry from SQL alias to the classpath location of the SQL file that implements it.
 *
 * <p>The registry is populated once, during class initialization, from the classpath:
 * <ol>
 *   <li>{@code module.xml} (optional): every top-level {@code <module name="...">} element, and every
 *       direct child {@code <module>} of it, names a module whose definitions are read from
 *       {@code <name>/sql.xml} or {@code <parent>/<name>/sql.xml}. A listed module whose file is
 *       missing is an error.</li>
 *   <li>{@code sql.xml} (optional): top-level SQL definitions, loaded after all modules.</li>
 * </ol>
 * Each {@code <sql>} element contributes one entry keyed by its {@code alias} attribute. The entry holds
 * {@code alias}, {@code path} (prefixed with the module directory and {@code sqls/}), {@code mobile}
 * and {@code language} (a missing language means "expression").
 *
 * <p>Duplicate aliases are rejected only while loading module files. Entries from the top-level
 * {@code sql.xml} silently replace earlier entries with the same alias. Any loading failure escapes
 * class initialization as an unchecked exception.
 */
@Component
public class SqlMapper {

	static final Logger LOGGER = Logger.getLogger(SqlMapper.class);

	/** alias -> attributes of that alias; filled by {@link #loadMap()} during class initialization. */
	private static Map<String, Map<String, String>> map;

	static {
		// Load eagerly so the mapping is ready as soon as the class is first used.
		loadMap();
	}

	/**
	 * Resolves a SQL alias to the path of its SQL file.
	 *
	 * <p>The dialect-specific alias ({@code alias + "_oracle"} or {@code alias + "_mysql"}, selected by
	 * {@link EntityServer#findDialect()}) is preferred when it is registered; otherwise the plain alias is
	 * used. For the chosen alias, an override found by {@link #getSqlByUser(String)} takes precedence over
	 * the configured path.
	 *
	 * @param alias the SQL alias as declared in a {@code sql.xml} file
	 * @return the override path (starting with {@code /}) if one exists; otherwise the configured path
	 *         (relative, e.g. {@code sqls/...} or {@code <module>/sqls/...}); or {@code null} if neither
	 *         the dialect-specific nor the plain alias is registered
	 */
	public static String getSql(String alias) {
		String newAlias;
		switch (EntityServer.findDialect()) {
			case EntityServer.DIALECT_ORACLE:
				newAlias = alias + "_oracle";
				break;
			case EntityServer.DIALECT_MYSQL:
				newAlias = alias + "_mysql";
				break;
			default:
				newAlias = alias;
		}
		if (!map.containsKey(newAlias)) {
			// No dialect-specific definition: fall back to the generic one, if any.
			if (!map.containsKey(alias)) {
				return null;
			}
			newAlias = alias;
		}
		String override = getSqlByUser(newAlias);
		if (override != null) {
			return override;
		}
		return map.get(newAlias).get("path");
	}

	/**
	 * Returns one configured attribute of an alias.
	 *
	 * @param alias the SQL alias (exact key, no dialect resolution)
	 * @param name  attribute name, e.g. {@code "path"}, {@code "mobile"}, {@code "language"}
	 * @return the attribute value, or {@code null} if the alias is unknown or the attribute was not set
	 */
	public static String getAttr(String alias, String name) {
		Map<String, String> attrs = map.get(alias);
		return attrs == null ? null : attrs.get(name);
	}

	/**
	 * Returns the alias registry itself (not a copy).
	 *
	 * <p>NOTE(review): callers receive the live internal map; modifying it changes what
	 * {@link #getSql(String)} and {@link #getAttr(String, String)} return.
	 *
	 * @return alias -> attribute map
	 */
	public static Map<String, Map<String, String>> getMap() {
		return SqlMapper.map;
	}

	/**
	 * Looks for an override SQL file at {@code /<dir>/sqls/<alias>.sql} on the classpath.
	 *
	 * <p>{@code dir} is read from {@code ThreadResource.ComponentDir}. NOTE(review): this is presumably a
	 * per-thread value; the original comment described it as the logged-in user's name. Confirm.
	 *
	 * @param alias the (already dialect-resolved) SQL alias
	 * @return the override resource path, or {@code null} if no directory is set or no such resource exists
	 */
	public static String getSqlByUser(String alias) {
		String dir = ThreadResource.ComponentDir.get();
		if (dir == null) {
			return null;
		}
		String path = "/" + dir + "/sqls/" + alias + ".sql";
		return SqlMapper.class.getResource(path) != null ? path : null;
	}

	/**
	 * Builds the registry: first every module listed in {@code module.xml}, then the top-level
	 * {@code sql.xml}. Both files are optional. All opened streams are closed.
	 *
	 * @throws RuntimeException wrapping any parse or lookup failure
	 */
	private static void loadMap() {
		map = new HashMap<>();

		SAXReader reader = new SAXReader();
		ClassLoader loader = SqlMapper.class.getClassLoader();
		try {
			// Module definitions first, so their duplicate-alias checks run before sql.xml is merged in.
			try (InputStream module = loader.getResourceAsStream("module.xml")) {
				if (module != null) {
					parseModule(module, reader);
				}
			}
			// Top-level SQL definitions.
			try (InputStream input = loader.getResourceAsStream("sql.xml")) {
				if (input != null) {
					loadOneMap(input, reader, null, null);
				}
			}
		} catch (Exception ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Parses one {@code sql.xml} document and registers each of its {@code <sql>} elements.
	 *
	 * @param input  the document stream (closed by the caller)
	 * @param reader the XML reader to use
	 * @param name   module name, or {@code null} for the top-level {@code sql.xml}
	 * @param parent parent module name, or {@code null} for a top-level module
	 * @throws Exception if the document cannot be parsed, a required attribute is missing, or (module
	 *                   files only) an alias is already registered
	 */
	private static void loadOneMap(InputStream input, SAXReader reader, String name, String parent) throws Exception {
		Document document = reader.read(input);
		Element root = document.getRootElement();
		for (Iterator<?> it = root.elementIterator("sql"); it.hasNext();) {
			Element element = (Element) it.next();
			String aliasLogic = requiredAttribute(element, "alias");
			String pathLogic = requiredAttribute(element, "path");
			String mobile = element.attributeValue("mobile");
			// SQL language, e.g. "func" or "expression"; null means expression.
			String language = element.attributeValue("language");

			// Aliases must be unique across module files; the top-level sql.xml may replace entries.
			if (name != null && map.containsKey(aliasLogic)) {
				throw new IllegalStateException("别名" + aliasLogic + "已存在");
			}

			Map<String, String> logic = new HashMap<>();
			logic.put("alias", aliasLogic);
			logic.put("path", getPath(name, parent, pathLogic));
			logic.put("mobile", mobile);
			logic.put("language", language);
			map.put(aliasLogic, logic);
		}
	}

	// Returns a mandatory attribute's value, failing with a descriptive message instead of an NPE.
	private static String requiredAttribute(Element element, String attrName) {
		String value = element.attributeValue(attrName);
		if (value == null) {
			throw new IllegalStateException(
					"Missing required attribute '" + attrName + "' on <" + element.getName() + "> element");
		}
		return value;
	}

	// Classpath directory prefix of a module: "" (top level), "name/" or "parent/name/".
	private static String moduleDir(String name, String parent) {
		if (name == null) {
			return "";
		}
		return parent == null ? name + "/" : parent + "/" + name + "/";
	}

	/**
	 * Builds the stored path of a SQL file: the module directory prefix, then {@code sqls/}, then the
	 * configured path.
	 *
	 * @param name   module name, or {@code null} outside modules
	 * @param parent parent module name, or {@code null}
	 * @param path   the {@code path} attribute read from the configuration file
	 * @return e.g. {@code sqls/a.sql}, {@code m/sqls/a.sql} or {@code p/m/sqls/a.sql}
	 */
	private static String getPath(String name, String parent, String path) {
		return moduleDir(name, parent) + "sqls/" + path;
	}

	/**
	 * Loads every module named in {@code module.xml}: each top-level {@code <module>} and its direct
	 * child {@code <module>} elements. Deeper nesting is not visited.
	 */
	private static void parseModule(InputStream input, SAXReader reader) throws Exception {
		Document document = reader.read(input);
		Element root = document.getRootElement();
		for (Iterator<?> it = root.elementIterator("module"); it.hasNext();) {
			Element elm = (Element) it.next();
			String name = requiredAttribute(elm, "name");
			putMap(name, reader, null);
			for (Iterator<?> children = elm.elementIterator("module"); children.hasNext();) {
				Element childElm = (Element) children.next();
				putMap(requiredAttribute(childElm, "name"), reader, name);
			}
		}
	}

	/**
	 * Loads one module's {@code sql.xml} from {@code <name>/sql.xml} or {@code <parent>/<name>/sql.xml}.
	 *
	 * @throws RuntimeException if the module file does not exist on the classpath
	 */
	private static void putMap(String name, SAXReader reader, String parent) throws Exception {
		String str = moduleDir(name, parent) + "sql.xml";
		try (InputStream moduleSql = SqlMapper.class.getClassLoader().getResourceAsStream(str)) {
			if (moduleSql == null) {
				throw new RuntimeException("注意！！！找不到文件：" + str);
			}
			loadOneMap(moduleSql, reader, name, parent);
		}
	}
}
