package com.af.v4.system.common.jpa.dynamic;

import com.af.v4.system.common.core.constant.HttpStatus;
import com.af.v4.system.common.core.exception.ServiceException;
import com.af.v4.system.common.core.utils.SpringUtils;
import com.af.v4.system.common.datasource.DynamicDataSource;
import com.af.v4.system.common.datasource.wrapper.AfDataSourceWrapper;
import com.af.v4.system.common.jpa.utils.SessionFactoryBeanBuilder;
import jakarta.persistence.EntityGraph;
import jakarta.persistence.EntityManager;
import jakarta.persistence.PersistenceUnitUtil;
import jakarta.persistence.SynchronizationType;
import jakarta.persistence.metamodel.Metamodel;
import org.hibernate.*;
import org.hibernate.boot.Metadata;
import org.hibernate.boot.MetadataSources;
import org.hibernate.boot.spi.SessionFactoryOptions;
import org.hibernate.cfg.Configuration;
import org.hibernate.engine.spi.FilterDefinition;
import org.hibernate.graph.RootGraph;
import org.hibernate.query.criteria.HibernateCriteriaBuilder;
import org.hibernate.relational.SchemaManager;
import org.hibernate.stat.Statistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.orm.hibernate5.LocalSessionFactoryBean;

import javax.naming.NamingException;
import javax.naming.Reference;
import java.io.Serial;
import java.sql.Connection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Dynamic data source {@link SessionFactory} implementation.
 *
 * <p>Keeps one lazily-built {@link LocalSessionFactoryBean} per data source name and forwards every
 * {@code SessionFactory} call to the factory of the data source named by
 * {@link DynamicDataSource#getDataSource()} (presumably the current routing key — confirm).
 *
 * <p>Built factories are cached in a {@link ConcurrentHashMap}; building a missing factory is
 * serialized by a single {@link ReentrantLock} (check, lock, re-check), so reads of an
 * already-built factory never block.
 */
public class DynamicSessionFactoryImpl implements DynamicSessionFactory {

    private static final Logger LOGGER = LoggerFactory.getLogger(DynamicSessionFactoryImpl.class);

    @Serial
    private static final long serialVersionUID = 5384069312247414885L;

    /** Delay, in milliseconds, between attempts to build the default data source's factory. */
    private static final long DEFAULT_WAIT_TIME = 5000L;

    /** Built session factory beans keyed by data source name. */
    private final Map<String, LocalSessionFactoryBean> localSessionFactoryBeans;

    /** Serializes construction of missing session factory beans. */
    private final ReentrantLock lock = new ReentrantLock();

    /**
     * Creates an empty cache pre-sized for the data sources currently registered in
     * {@link DynamicDataSource#getDataSourceMap()}.
     */
    public DynamicSessionFactoryImpl() {
        this.localSessionFactoryBeans = new ConcurrentHashMap<>(DynamicDataSource.getDataSourceMap().size());
    }

    /**
     * Returns the Hibernate session factory of the currently selected data source.
     *
     * @return the factory held by {@link #getLocalSessionFactoryBean()}
     */
    @Override
    public SessionFactory getHibernateSessionFactory() {
        return getLocalSessionFactoryBean().getObject();
    }

    /**
     * Returns the Hibernate configuration of the currently selected data source.
     *
     * @return the configuration held by {@link #getLocalSessionFactoryBean()}
     */
    @Override
    public Configuration getConfiguration() {
        return getLocalSessionFactoryBean().getConfiguration();
    }

    /**
     * Returns the session factory bean for {@code dataSourceName}, building and caching it on first use.
     *
     * <p>Resolution rules:
     * <ul>
     *   <li>An already-cached bean is returned without taking the lock.</li>
     *   <li>If no data source is registered under the name, the default data source's bean is
     *       returned instead (it is not cached under the requested name).</li>
     *   <li>If building fails with a {@link BeanCreationException}, or fails in any way for a
     *       non-default data source, a {@link ServiceException} is thrown.</li>
     *   <li>Any other failure for the default data source is logged and retried every
     *       {@link #DEFAULT_WAIT_TIME} milliseconds until it succeeds. The lock stays held while
     *       waiting, so no other data source can be built during that time.</li>
     * </ul>
     *
     * @param dataSourceName data source name to resolve
     * @return the session factory bean for the data source, or for the default data source
     * @throws ServiceException if the default data source is not registered or a connection fails
     * @throws RuntimeException if the thread is interrupted while waiting to retry
     */
    private LocalSessionFactoryBean getLocalSessionFactoryBean(String dataSourceName) {
        // Fast path: the factory has already been built and cached
        LocalSessionFactoryBean cached = localSessionFactoryBeans.get(dataSourceName);
        if (cached != null) {
            return cached;
        }
        lock.lock();
        try {
            String name = dataSourceName;
            // A loop rather than recursion: retrying the default data source during a long outage
            // would otherwise add a stack frame every DEFAULT_WAIT_TIME milliseconds.
            while (true) {
                // Re-check under the lock: another thread may have built it in the meantime
                cached = localSessionFactoryBeans.get(name);
                if (cached != null) {
                    return cached;
                }
                boolean isDefault = DynamicDataSource.DEFAULT_DATASOURCE_NAME.equals(name);
                AfDataSourceWrapper wrapper = DynamicDataSource.getDataSourceMap().get(name);
                if (wrapper == null) {
                    if (isDefault) {
                        throw new ServiceException("未指定默认数据源 : [" + name + "]");
                    }
                    // No such data source: fall back to the default one
                    name = DynamicDataSource.DEFAULT_DATASOURCE_NAME;
                    continue;
                }
                try {
                    LocalSessionFactoryBean bean = SpringUtils.getBean(SessionFactoryBeanBuilder.class)
                            .buildSessionFactory(name, wrapper);
                    localSessionFactoryBeans.put(name, bean);
                    return bean;
                } catch (BeanCreationException e) {
                    throw connectionFailure(name, e);
                } catch (Exception e) {
                    if (!isDefault) {
                        throw connectionFailure(name, e);
                    }
                    LOGGER.error("数据源[{}]连接失败，{}毫秒后重试...", DynamicDataSource.DEFAULT_DATASOURCE_NAME, DEFAULT_WAIT_TIME, e);
                    waitBeforeRetry();
                }
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * Logs a data source connection failure together with its cause and builds the exception
     * reported to callers. The cause is logged here because the returned exception is
     * constructed from the message only.
     *
     * @param dataSourceName the data source that failed
     * @param cause          the failure raised while building the session factory
     * @return the exception to throw
     */
    private static ServiceException connectionFailure(String dataSourceName, Exception cause) {
        LOGGER.error("数据源[{}]连接失败", dataSourceName, cause);
        return new ServiceException("数据源[" + dataSourceName + "]连接失败————" + cause.getMessage(), HttpStatus.ERROR);
    }

    /**
     * Sleeps {@link #DEFAULT_WAIT_TIME} milliseconds before the next build attempt.
     *
     * @throws RuntimeException wrapping the {@link InterruptedException} if interrupted
     */
    private static void waitBeforeRetry() {
        try {
            Thread.sleep(DEFAULT_WAIT_TIME);
        } catch (InterruptedException ex) {
            // Restore the interrupt flag so code further up the stack can still observe it
            Thread.currentThread().interrupt();
            throw new RuntimeException(ex);
        }
    }

    /**
     * Returns the session factory bean of the data source named by
     * {@link DynamicDataSource#getDataSource()}.
     *
     * @return the (possibly newly built) session factory bean
     */
    @Override
    public LocalSessionFactoryBean getLocalSessionFactoryBean() {
        return getLocalSessionFactoryBean(DynamicDataSource.getDataSource());
    }

    /**
     * Builds Hibernate {@link Metadata} from the current data source's configuration.
     *
     * <p>NOTE(review): a new standard service registry is built on every call and is not destroyed
     * here; confirm whether callers release it or whether the result should be cached.
     *
     * @return freshly built metadata
     */
    @Override
    public Metadata getMetadata() {
        MetadataSources metadataSources = new MetadataSources(
                getConfiguration().getStandardServiceRegistryBuilder().build()
        );
        return metadataSources.buildMetadata();
    }

    // ---------------------------------------------------------------------------------------------
    // The methods below delegate to the session factory of the currently selected data source,
    // as returned by getHibernateSessionFactory().
    // ---------------------------------------------------------------------------------------------

    @Override
    public SessionFactoryOptions getSessionFactoryOptions() {
        return getHibernateSessionFactory().getSessionFactoryOptions();
    }

    @Override
    public SessionBuilder withOptions() {
        return getHibernateSessionFactory().withOptions();
    }

    @Override
    public Session openSession() throws HibernateException {
        return getHibernateSessionFactory().openSession();
    }

    @Override
    public Session getCurrentSession() throws HibernateException {
        return getHibernateSessionFactory().getCurrentSession();
    }

    @Override
    public StatelessSessionBuilder withStatelessOptions() {
        return getHibernateSessionFactory().withStatelessOptions();
    }

    @Override
    public StatelessSession openStatelessSession() {
        return getHibernateSessionFactory().openStatelessSession();
    }

    @Override
    public StatelessSession openStatelessSession(Connection connection) {
        return getHibernateSessionFactory().openStatelessSession(connection);
    }

    @Override
    public Statistics getStatistics() {
        return getHibernateSessionFactory().getStatistics();
    }

    @Override
    public SchemaManager getSchemaManager() {
        return getHibernateSessionFactory().getSchemaManager();
    }

    /**
     * Closes every session factory built so far, across all data sources.
     *
     * <p>Beans without a factory and factories that are already closed are skipped, and no new
     * factory is built just to be closed. Every open factory is attempted; the first
     * {@link HibernateException} is rethrown afterwards with any later ones attached as suppressed.
     *
     * @throws HibernateException if closing any factory failed
     */
    @Override
    public void close() throws HibernateException {
        HibernateException failure = null;
        for (LocalSessionFactoryBean bean : localSessionFactoryBeans.values()) {
            SessionFactory sessionFactory = bean.getObject();
            if (sessionFactory == null || sessionFactory.isClosed()) {
                continue;
            }
            try {
                sessionFactory.close();
            } catch (HibernateException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    @Override
    public Map<String, Object> getProperties() {
        return getHibernateSessionFactory().getProperties();
    }

    @Override
    public boolean isClosed() {
        return getHibernateSessionFactory().isClosed();
    }

    @Override
    public Cache getCache() {
        return getHibernateSessionFactory().getCache();
    }

    @Override
    public PersistenceUnitUtil getPersistenceUnitUtil() {
        return getHibernateSessionFactory().getPersistenceUnitUtil();
    }

    @Override
    public void addNamedQuery(String name, jakarta.persistence.Query query) {
        getHibernateSessionFactory().addNamedQuery(name, query);
    }

    @Override
    public <T> T unwrap(Class<T> cls) {
        return getHibernateSessionFactory().unwrap(cls);
    }

    @Override
    public <T> void addNamedEntityGraph(String graphName, EntityGraph<T> entityGraph) {
        getHibernateSessionFactory().addNamedEntityGraph(graphName, entityGraph);
    }

    @Override
    public Set<String> getDefinedFilterNames() {
        return getHibernateSessionFactory().getDefinedFilterNames();
    }

    @Override
    public FilterDefinition getFilterDefinition(String filterName) throws HibernateException {
        return getHibernateSessionFactory().getFilterDefinition(filterName);
    }

    @Override
    public Set<String> getDefinedFetchProfileNames() {
        return getHibernateSessionFactory().getDefinedFetchProfileNames();
    }

    @Override
    public boolean containsFetchProfileDefinition(String name) {
        return getHibernateSessionFactory().containsFetchProfileDefinition(name);
    }

    @Override
    public Reference getReference() throws NamingException {
        return getHibernateSessionFactory().getReference();
    }

    @Override
    public <T> List<EntityGraph<? super T>> findEntityGraphsByType(Class<T> entityClass) {
        return getHibernateSessionFactory().findEntityGraphsByType(entityClass);
    }

    @Override
    public RootGraph<?> findEntityGraphByName(String name) {
        return getHibernateSessionFactory().findEntityGraphByName(name);
    }

    @Override
    public EntityManager createEntityManager() {
        return getHibernateSessionFactory().createEntityManager();
    }

    @Override
    public EntityManager createEntityManager(Map map) {
        return getHibernateSessionFactory().createEntityManager(map);
    }

    @Override
    public EntityManager createEntityManager(SynchronizationType synchronizationType) {
        return getHibernateSessionFactory().createEntityManager(synchronizationType);
    }

    @Override
    public EntityManager createEntityManager(SynchronizationType synchronizationType, Map map) {
        return getHibernateSessionFactory().createEntityManager(synchronizationType, map);
    }

    @Override
    public HibernateCriteriaBuilder getCriteriaBuilder() {
        return getHibernateSessionFactory().getCriteriaBuilder();
    }

    @Override
    public Metamodel getMetamodel() {
        return getHibernateSessionFactory().getMetamodel();
    }

    @Override
    public boolean isOpen() {
        return getHibernateSessionFactory().isOpen();
    }
}
