# Standard Library
import os
from contextlib import redirect_stdout, redirect_stderr
from typing import Callable, Optional, Tuple, List

# Third-party: Clustering algorithms and Multiple Initializations
from joblib import Parallel, delayed # for parallel processing
from kmodes.kmodes import KModes     # for K-Modes clustering
from stepmix.stepmix import StepMix  # for BMM clustering
import numpy as np

# Third-party: Quantum Chemistry
import pyscf
import pyscf.mcscf
from pyscf import fci

# Third-party: Qiskit Addon for SQD
from qiskit_addon_sqd.configuration_recovery import _p_flip_0_to_1, _p_flip_1_to_0  
from qiskit_addon_sqd.counts import bit_array_to_arrays, bitstring_matrix_to_integers 
from qiskit_addon_sqd.fermion import _unique_with_order_preserved 
from qiskit_addon_sqd.subsampling import postselect_by_hamming_right_and_left 


def define_n2(
    distance: float,
    *,
    basis: str = "cc-pvdz",
    symmetry: str = "Dooh",
    n_frozen: int = 2,
    ) -> Tuple[np.ndarray, np.ndarray, float]:
    """
    Computes active-space Hamiltonian components (hcore, eri) and the constant energy offset for N2.

    Runs a Restricted Hartree-Fock (RHF) calculation with PySCF, builds a CASCI object whose
    active space is every orbital except the ``n_frozen`` lowest ones, and extracts the
    active-space 1-body and 2-body integrals.

    Args:
        distance (float): Interatomic distance between the two Nitrogen (N) atoms (Unit: Angstrom).
        basis (str, optional): Basis set passed to PySCF. Defaults to "cc-pvdz".
        symmetry (str, optional): Point-group symmetry passed to PySCF. Defaults to "Dooh".
        n_frozen (int, optional): Number of lowest orbitals excluded from the active space.
            Defaults to 2.

    Returns:
        Tuple[np.ndarray, np.ndarray, float]:
            - hcore (np.ndarray): Effective 1-electron integrals (h1cas) in the active space.
            - eri (np.ndarray): 2-electron repulsion integrals (h2cas) in the active space,
              restored to full (no permutational symmetry) 4-index form.
            - nuclear_repulsion_energy (float): The constant energy term returned by
              ``CASCI.get_h1cas``. This is nuclear repulsion plus the energy of the frozen
              core orbitals; with ``n_frozen=0`` it is pure nuclear repulsion.

    Raises:
        ValueError: If ``n_frozen`` is not in ``[0, number of atomic orbitals)``.
    """

    # Generate the N2 molecule along the x axis.
    mol = pyscf.gto.Mole()
    mol.build(
        atom=[["N", (0, 0, 0)], ["N", (distance, 0, 0)]],
        basis=basis,
        symmetry=symmetry,
    )

    # Active space: all orbitals above the first n_frozen ones.
    # mol.nao_nr() is the total number of atomic orbitals.
    n_ao = mol.nao_nr()
    if not 0 <= n_frozen < n_ao:
        raise ValueError(f"n_frozen must be in [0, {n_ao}), got {n_frozen}")
    active_space = range(n_frozen, n_ao)

    # Mean-field reference.
    scf = pyscf.scf.RHF(mol).run()

    # Number of orbitals and electrons inside the active space.
    num_orbitals = len(active_space)
    n_electrons = int(sum(scf.mo_occ[active_space]))

    # Split electrons by spin (mol.spin = N_alpha - N_beta; 0 for this closed-shell build).
    num_elec_a = (n_electrons + mol.spin) // 2
    num_elec_b = (n_electrons - mol.spin) // 2

    # CASCI object; sort_mo places the chosen (0-based) orbitals into the active window.
    cas = pyscf.mcscf.CASCI(scf, num_orbitals, (num_elec_a, num_elec_b))
    mo = cas.sort_mo(active_space, base=0)

    # 1-body active-space integrals plus the constant (nuclear + frozen-core) energy.
    hcore, nuclear_repulsion_energy = cas.get_h1cas(mo)

    # 2-body integrals, restored from packed storage to the full 4-index tensor.
    eri = pyscf.ao2mo.restore(1, cas.get_h2cas(mo), num_orbitals)

    return hcore, eri, nuclear_repulsion_energy


    

def _truncate_and_sort_batch_basis(
    batch_bool: np.ndarray,
    batch_probs: np.ndarray,
    membership: np.ndarray,
    max_dim: Optional[int],
    ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Turns one batch into the CI-string basis (and matching membership rows) for the SCI solver.

    Configurations are ranked by descending probability. The first (most probable) occurrence
    of each string is kept, the list is truncated to ``max_dim`` strings (``None`` keeps all),
    and the result is sorted ascending for PySCF.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - unique_strs: Integer CI strings, ascending.
            - membership: Membership rows belonging to exactly those strings, in the same order.
    """
    batch_ints = np.asarray(bitstring_matrix_to_integers(batch_bool))
    membership = np.asarray(membership)

    # Importance ordering: most probable configurations first.
    by_prob = np.argsort(batch_probs)[::-1]
    batch_ints = batch_ints[by_prob]
    membership = membership[by_prob]

    # Order-preserving de-duplication followed by truncation. The SAME row indices are
    # applied to `membership`, so basis and membership stay aligned even if the batch
    # contains repeated strings. Slicing the two arrays independently would misalign them.
    _, first_idx = np.unique(batch_ints, return_index=True)
    keep = np.sort(first_idx)[:max_dim]
    unique_strs = batch_ints[keep]
    membership = membership[keep]

    # PySCF's selected-CI routines expect ascending CI strings.
    by_value = np.argsort(unique_strs)
    return unique_strs[by_value], membership[by_value]


def _solve_batches_with_membership(
    batches,
    max_dim: Optional[int],
    hcore: np.ndarray,
    eri: np.ndarray,
    n_orb: int,
    n_elec: int,
    nuclear_repulsion_energy: float,
    ) -> Tuple[list, list[np.ndarray]]:
    """
    Diagonalizes every (basis, probs, membership) batch and prints per-batch progress.

    Returns:
        Tuple[list, List[np.ndarray]]: The solver result of each batch and the basis used for it,
        both in batch order.
    """
    results = []
    bases = []
    for i, (batch_bool, batch_probs, membership) in enumerate(batches):
        unique_strs, membership = _truncate_and_sort_batch_basis(
            batch_bool, batch_probs, membership, max_dim
        )
        bases.append(unique_strs)

        # Restricted CI: the same string set is used for alpha and beta.
        result = solve_sci_cluster(
            (unique_strs, unique_strs),
            hcore,
            eri,
            n_orb,
            (n_elec, n_elec),
            nuclear_repulsion_energy,
            membership,
        )
        results.append(result)

        print(f"        Batch {i+1}:")
        print(f"        Energy: {result[2]:.6f}")
        print(f"        Subspace dimension: {len(unique_strs)**2}")
    return results, bases


def diagonalize_fermionic_hamiltonian_with_clustering(
    bitarray: Tuple[np.ndarray, np.ndarray] | np.ndarray,
    n_orb: int,
    n_elec: int,
    n_clusters: int,
    n_batch: int,
    threshold: float,
    samples_per_batch: int,
    max_iterations: int,
    hcore: np.ndarray,
    eri: np.ndarray,
    nuclear_repulsion_energy: float,
    max_dim: Optional[int] = None,
    energy_tol: float = 1e-8,
    occupancies_tol: float = 1e-5,
    clustering_func: Optional[Callable] = None,
    rng: Optional[np.random.Generator] = None
    ) -> Tuple[tuple, list[float], list[np.ndarray], list[np.ndarray], np.ndarray]:
    """
    Computes the ground state energy of a fermionic Hamiltonian using clustering and basis refinement algorithms.

    Args:
        bitarray (Tuple[np.ndarray, np.ndarray] | np.ndarray): Raw measurement samples, either as a
            ``(bitstring_matrix, probabilities)`` tuple or as an object accepted by
            ``bit_array_to_arrays``.
        n_orb (int): Number of spatial orbitals.
        n_elec (int): Number of electrons per spin (assuming alpha == beta).
        n_clusters (int): Number of clusters for partitioning the Hilbert space.
        n_batch (int): Number of batches generated and diagonalized (sequentially) per iteration.
        threshold (float): Coefficient threshold for filtering important basis functions.
        samples_per_batch (int): Total sampling budget per batch, split across clusters by weight.
        max_iterations (int): Maximum number of refinement cycles (including the initial one).
        hcore (np.ndarray): 1-electron integrals of the active space.
        eri (np.ndarray): 2-electron integrals of the active space.
        nuclear_repulsion_energy (float): Constant energy offset added by the SCI solver.
        max_dim (int, optional): Maximum number of CI strings per batch. Defaults to None (no limit).
        energy_tol (float, optional): Convergence tolerance for energy. Defaults to 1e-8.
        occupancies_tol (float, optional): Convergence tolerance for orbital occupancy (n-vector). Defaults to 1e-5.
        clustering_func (Callable, optional): External clustering function. Defaults to K-Modes if None.
        rng (np.random.Generator, optional): Generator used for cluster refinement.
            Defaults to ``np.random.default_rng(42)``.

    Returns:
        Tuple[tuple, List[float], List[np.ndarray], List[np.ndarray], np.ndarray]:
            - best_result (tuple): The optimal result tuple (ci_strs, coef, energy, membership).
            - energy_history (List[float]): Best energy after each iteration.
            - basis_history (List[np.ndarray]): Basis set of every diagonalized batch.
            - n_vec_history (List[np.ndarray]): History of orbital occupancy profiles (n-vectors).
            - cluster_weight (np.ndarray): Final weights of the clusters.

    Raises:
        ValueError: If ``n_batch`` is smaller than 1.
    """

    # =========================================================================
    # SECTION 1: Initialization & Defaults
    # =========================================================================

    if n_batch < 1:
        raise ValueError(f"n_batch must be at least 1, got {n_batch}")

    if rng is None:
        rng = np.random.default_rng(42)

    if clustering_func is None:
        clustering_func = assign_clusters_kmodes  # Default: K-Modes algorithm

    energy_history = []
    basis_history = []
    n_vec_history = []

    # =========================================================================
    # SECTION 2: Preprocessing & Clustering
    # =========================================================================

    # 1. Unify input data format: (bitstring_matrix, probabilities).
    if isinstance(bitarray, tuple):
        input_data = bitarray
    else:
        input_data = bit_array_to_arrays(bitarray)

    # 2. Pool alpha/beta spin blocks into one distribution.
    ci_strs, ci_probs = pool_spin_blocks(n_orb, input_data)

    # 3. Partition the pooled spin blocks into clusters.
    clusters, cluster_probs, cluster_weight = clustering_func(ci_strs, ci_probs, n_clusters)

    # 4. Assign particle-number-correct samples to the existing clusters.
    correct_samples_in_clusters, correct_samples_probs_in_clusters = sort_correct_samples_by_existing_clusters(
        input_data, n_elec, clusters
    )

    # 5. Initial orbital occupancy profiles (n-vectors) per cluster.
    raw_n_vectors = calc_raw_n_vectors(clusters, cluster_probs, n_orb)

    # 6. Split the per-batch sampling budget proportionally to cluster weight.
    samples_per_cluster = tuple(
        np.ceil(cluster_weight[i] * samples_per_batch) for i in range(len(clusters))
    )

    # =========================================================================
    # SECTION 3: Iteration 1 (Initial Basis Exploration)
    # =========================================================================

    initial_batches = make_initial_batches(
        n_batch,
        n_elec,
        n_orb,
        raw_n_vectors,
        samples_per_cluster,
        correct_samples_in_clusters,
        correct_samples_probs_in_clusters,
    )

    print("Iteration 1")
    initial_results, initial_bases = _solve_batches_with_membership(
        initial_batches, max_dim, hcore, eri, n_orb, n_elec, nuclear_repulsion_energy
    )
    basis_history.extend(initial_bases)

    best_result = min(initial_results, key=lambda x: x[2])
    energy_history.append(best_result[2])
    n_vec_history.append(raw_n_vectors)
    print(f"        Estimated Ground State Energy: {best_result[2]:.8f}")

    # Convergence bookkeeping: a large initial energy gap guarantees at least one refinement.
    prev_energy = best_result[2] + 10.0
    prev_n_vec = np.array(raw_n_vectors).copy()

    # =========================================================================
    # SECTION 4: Main Loop (Generative Refinement)
    # =========================================================================

    for j in range(max_iterations - 1):

        # 1. Update guidance (n-vector) from the current best result.
        n_vec = cal_n_vecs(best_result, prev_n_vec, n_elec, n_orb, n_clusters)
        n_vec_history.append(n_vec)

        # 2. Convergence check on energy and on the max absolute n-vector change.
        energy_diff = abs(best_result[2] - prev_energy)
        occupancy_diff = np.max(np.abs(n_vec - prev_n_vec))

        print(f"Iteration {j+2}: dE = {energy_diff:.2e}, dN_max = {occupancy_diff:.2e}")

        if energy_diff < energy_tol and occupancy_diff < occupancies_tol:
            print(f"Convergence achieved at Iteration {j+2}!")
            break

        prev_energy = best_result[2]

        # 3. Refine each cluster using the updated guidance.
        refined_clusters = [
            refine_a_cluster(clusters[k], cluster_probs[k], n_vec[k], n_elec, n_orb, rng)
            for k in range(n_clusters)
        ]

        # 4. Next-generation batches: carry-over of important basis + refined samples.
        batches = make_batches(
            n_batch,
            best_result,
            refined_clusters,
            samples_per_cluster,
            n_orb,
            threshold,
            max_dim,
            base_seed=None
        )

        print(f"Iteration {j+2}")

        # 5. Diagonalize every batch.
        results, bases = _solve_batches_with_membership(
            batches, max_dim, hcore, eri, n_orb, n_elec, nuclear_repulsion_energy
        )
        basis_history.extend(bases)

        # 6. Keep the global best.
        current_gen_best = min(results, key=lambda x: x[2])
        if current_gen_best[2] < best_result[2]:
            best_result = current_gen_best
            print(f"        Estimated Ground State Energy: {best_result[2]:.8f}")

        energy_history.append(best_result[2])
        prev_n_vec = np.array(n_vec).copy()

    return best_result, energy_history, basis_history, n_vec_history, cluster_weight

    



def pool_spin_blocks(
    n_orb,
    input_data
    ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Merges the beta and alpha halves of every bitstring into one normalized pool of spin blocks.

    Every full bitstring yields two blocks of length ``n_orb``:
    - its left half (columns ``0 .. n_orb-1``, beta);
    - its right half (columns ``n_orb .. 2*n_orb-1``, alpha).

    Both halves inherit the probability of the full bitstring. Identical blocks are merged by
    summing their probabilities, and the merged weights are rescaled to sum to 1.

    Args:
        n_orb (int): Number of spatial orbitals (half the number of columns).
        input_data (Tuple[np.ndarray, np.ndarray]): ``(bitstring_matrix, probabilities)``.
            The matrix has shape (N_samples, 2*n_orb); probabilities has shape (N_samples,).

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - unique_blocks (np.ndarray): Distinct spin blocks, shape (N_unique, n_orb),
              in ``np.unique`` (lexicographic) order.
            - normalized_probs (np.ndarray): Pooled probability of each block (sums to 1.0).
    """
    full_strings, weights = input_data

    # Beta half first, then alpha half; each copy of the weights lines up with its half.
    blocks = np.concatenate((full_strings[:, :n_orb], full_strings[:, n_orb:]), axis=0)
    block_weights = np.tile(weights, 2)

    # Merge identical blocks, accumulating their weights in place.
    distinct_blocks, owner = np.unique(blocks, axis=0, return_inverse=True)
    pooled = np.zeros(len(distinct_blocks))
    np.add.at(pooled, owner, block_weights)

    return distinct_blocks, pooled / pooled.sum()




def assign_clusters_kmodes(
    bitstrings: np.ndarray,
    probabilities: np.ndarray,
    k: int,
    random_state: int = 42,
    n_init: int = 100
    ) -> Tuple[list[np.ndarray], list[np.ndarray], np.ndarray]:
    """
    Splits bitstrings into k groups with probability-weighted K-Modes clustering.

    Args:
        bitstrings (np.ndarray): Data matrix of shape (N_samples, n_features); boolean or 0/1 integers.
        probabilities (np.ndarray): Per-sample weight, shape (N_samples,). Passed to K-Modes as
            ``sample_weight`` so probable bitstrings dominate mode formation.
        k (int): Number of clusters.
        random_state (int, optional): Seed for reproducibility. Defaults to 42.
        n_init (int, optional): Number of K-Modes restarts with different initial modes. Defaults to 100.

    Returns:
        Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]:
            - separated_clusters (List[np.ndarray]): Bitstrings of each cluster, indexed by label.
            - separated_probs (List[np.ndarray]): Probabilities of those bitstrings.
            - cluster_prob_sums (np.ndarray): Shape (k,), total probability of each cluster.
    """
    # 'Huang' initialization is designed for categorical data; n_jobs=-1 spreads restarts over all cores.
    model = KModes(
        n_clusters=k,
        init="Huang",
        n_init=n_init,
        verbose=0,
        random_state=random_state,
        n_jobs=-1,
    )
    labels = model.fit_predict(bitstrings, sample_weight=probabilities)

    # One boolean mask per cluster label; reuse it for strings and weights.
    masks = [labels == label for label in range(k)]
    separated_clusters = [bitstrings[m] for m in masks]
    separated_probs = [probabilities[m] for m in masks]
    cluster_prob_sums = np.array([np.sum(group) for group in separated_probs])

    return separated_clusters, separated_probs, cluster_prob_sums




def assign_clusters_bmm(
    bitstrings,
    probabilities,
    k,
    random_state=42,
    n_init=100
    ) -> Tuple[list[np.ndarray], list[np.ndarray], np.ndarray]:
    """
    Splits bitstrings into k groups with a Bernoulli Mixture Model (StepMix, binary measurement).

    ``n_init`` single-start EM fits run in parallel through joblib (``n_jobs=-1``). Each fit uses
    the seed ``random_state + i``. The fit with the largest ``lower_bound_`` is kept and used to
    label every bitstring.

    Args:
        bitstrings (np.ndarray): Binary data of shape (N_samples, n_features).
        probabilities (np.ndarray): Sample weights, shape (N_samples,).
        k (int): Number of mixture components (clusters).
        random_state (int, optional): Base seed. Defaults to 42.
        n_init (int, optional): Number of independent fits. Defaults to 100.

    Returns:
        Tuple[List[np.ndarray], List[np.ndarray], np.ndarray]:
            - separated_clusters (List[np.ndarray]): Bitstrings of each cluster, indexed by label.
            - separated_probs (List[np.ndarray]): Probabilities of those bitstrings.
            - cluster_prob_sums (np.ndarray): Total probability of each cluster.
    """

    def _fit_one(seed: int) -> Tuple[StepMix, float]:
        """Runs one single-start StepMix fit with stdout/stderr silenced; returns (model, lower bound)."""
        # Output is discarded by redirecting both streams to os.devnull.
        with open(os.devnull, 'w') as sink, redirect_stdout(sink), redirect_stderr(sink):
            candidate = StepMix(
                n_components=k,
                measurement="binary",
                verbose=0,
                random_state=seed,
                n_init=1,  # each worker performs exactly one EM run
            )
            candidate.fit(bitstrings, sample_weight=probabilities)
        return candidate, candidate.lower_bound_

    fits = Parallel(n_jobs=-1)(
        delayed(_fit_one)(random_state + offset) for offset in range(n_init)
    )

    # Highest lower bound wins (first one on ties).
    best_model = max(fits, key=lambda pair: pair[1])[0]
    labels = best_model.predict(bitstrings)

    masks = [labels == label for label in range(k)]
    separated_clusters = [bitstrings[m] for m in masks]
    separated_probs = [probabilities[m] for m in masks]
    cluster_prob_sums = np.array([np.sum(group) for group in separated_probs])

    return separated_clusters, separated_probs, cluster_prob_sums
    


def sort_correct_samples_by_existing_clusters(
    input_data: Tuple[np.ndarray, np.ndarray],
    n_elec_per_spin: int,
    clusters: list[np.ndarray]
    ) -> Tuple[tuple[np.ndarray, ...], tuple[np.ndarray, ...]]:
    """
    Filters valid physical samples and maps them to existing clusters, computing conditional probabilities.

    Steps:
    1. Post-selection: keep only bitstrings whose right and left halves both have Hamming
       weight ``n_elec_per_spin``.
    2. Pooling: split the survivors into spin blocks and merge duplicates (``pool_spin_blocks``).
    3. Classification: look up each block's cluster and normalize probabilities within each cluster.

    Args:
        input_data (Tuple[np.ndarray, np.ndarray]): A tuple containing:
            - bitstring_matrix (np.ndarray): Array of shape (N_samples, 2*n_orb).
            - probabilities (np.ndarray): 1D array of probabilities matching the samples.
        n_elec_per_spin (int): Target number of electrons per spin (used for post-selection).
        clusters (List[np.ndarray]): Spin blocks defining each cluster. Blocks are compared as
            booleans, so 0/1 integer and boolean clusters behave the same.

    Returns:
        Tuple[Tuple[np.ndarray, ...], Tuple[np.ndarray, ...]]:
            - classified_samples (Tuple): One array per cluster, each of shape (N_k, n_orb).
              Empty clusters have shape (0, n_orb) and the samples' dtype.
            - conditional_probs (Tuple): One array per cluster holding P(x | x in cluster).

    Raises:
        KeyError: If a post-selected spin block does not belong to any cluster.
    """

    # 1. Setup: number of spatial orbitals = half of the total qubits.
    bitstring_matrix, probabilities = input_data
    n_orb = bitstring_matrix.shape[1] // 2

    # 2. Post-selection by particle-number conservation.
    selected_strs, selected_probs = postselect_by_hamming_right_and_left(
        bitstring_matrix,
        probabilities,
        hamming_right=n_elec_per_spin,
        hamming_left=n_elec_per_spin,
    )

    # 3. Pooling (spin separation & aggregation).
    correct_samples, correct_probs = pool_spin_blocks(n_orb, (selected_strs, selected_probs))
    n_clusters = len(clusters)

    def block_key(block: np.ndarray) -> bytes:
        # Normalize to bool so the key does not depend on the array dtype.
        return np.asarray(block, dtype=bool).tobytes()

    # 4. Lookup table: spin block -> cluster id (later clusters win on duplicates).
    block_to_cluster_map = {}
    for cluster_id, cluster_data in enumerate(clusters):
        for block in cluster_data:
            block_to_cluster_map[block_key(block)] = cluster_id

    # 5. Classification.
    sorted_samples_list = [[] for _ in range(n_clusters)]
    sorted_probs_list = [[] for _ in range(n_clusters)]
    for sample, prob in zip(correct_samples, correct_probs):
        c_id = block_to_cluster_map.get(block_key(sample))
        if c_id is None:
            raise KeyError(
                f"Post-selected spin block {sample.astype(int).tolist()} "
                "does not belong to any cluster."
            )
        sorted_samples_list[c_id].append(sample)
        sorted_probs_list[c_id].append(prob)

    # 6. Normalization to conditional probabilities P(x | x in cluster).
    final_samples = []
    final_probs = []
    for samples, probs in zip(sorted_samples_list, sorted_probs_list):
        if samples:
            s_arr = np.array(samples)
            p_arr = np.array(probs)
            cond_p_arr = p_arr / np.sum(p_arr)
        else:
            # Empty cluster: keep a consistent 2-D shape and dtype for downstream code.
            s_arr = np.empty((0, n_orb), dtype=correct_samples.dtype)
            cond_p_arr = np.array([])
        final_samples.append(s_arr)
        final_probs.append(cond_p_arr)

    return tuple(final_samples), tuple(final_probs)




def calc_raw_n_vectors(
    clusters: list[np.ndarray],
    cluster_probs: list[np.ndarray],
    n_orb: int
    ) -> tuple[np.ndarray, ...]:
    """
    Builds the probability-weighted orbital occupation profile ("raw n-vector") of each cluster.

    For a cluster with bitstrings ``x_i`` and probabilities ``p_i`` the n-vector is
    ``sum_i p_i * x_i``. It is deliberately left unnormalized, so clusters with more probability
    mass keep proportionally larger entries.

    Args:
        clusters (List[np.ndarray]): Bitstrings of each cluster, each of shape (N_k, n_orb).
        cluster_probs (List[np.ndarray]): Matching probabilities, each of shape (N_k,).
        n_orb (int): Number of spatial orbitals (not used by the computation itself).

    Returns:
        Tuple[np.ndarray, ...]: One float n-vector per cluster, each of shape (n_orb,).
    """
    return tuple(
        np.dot(weights, occupations.astype(float))
        for occupations, weights in zip(clusters, cluster_probs)
    )



def _orbital_sampling_distribution(raw_n_vec: np.ndarray, n_elec: int) -> np.ndarray:
    """
    Normalizes a raw n-vector into orbital-selection probabilities for synthetic sampling.

    ``rng.choice(n_orb, size=n_elec, replace=False, p=...)`` needs at least ``n_elec`` orbitals
    with non-zero probability. Only when the raw n-vector has fewer (or sums to zero) do the
    empty orbitals receive a small uniform floor. Otherwise the plain normalized n-vector is
    returned unchanged.
    """
    weights = np.asarray(raw_n_vec, dtype=float)
    if np.count_nonzero(weights > 0) < n_elec:
        total = weights.sum()
        floor = 1e-6 * total / weights.size if total > 0 else 1.0
        weights = np.where(weights > 0, weights, floor)
    return weights / np.sum(weights)


def initial_subsample_cluster(
    n_elec: int,
    n_orb: int,
    raw_n_vectors: tuple[np.ndarray, ...],
    samples_per_cluster: tuple[int, ...],
    correct_samples: tuple[np.ndarray, ...],
    correct_probs: tuple[np.ndarray, ...],
    seed: Optional[int] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Performs stratified sampling per cluster and generates synthetic samples if necessary.
    Returns a deduplicated basis set and a membership matrix tracking the cluster origin of each basis.

    For each cluster with a positive target count:
    - If it has at least that many real (measured) samples, draw them without replacement,
      weighted by their probabilities.
    - Otherwise keep all real samples and synthesize the deficit by drawing ``n_elec`` occupied
      orbitals from the cluster's normalized n-vector.

    Clusters whose target count is zero contribute nothing. Each cluster's weights are
    normalized and scaled by ``target_count / sum(samples_per_cluster)``.

    Args:
        n_elec (int): Number of electrons per spin (Hamming weight constraint).
        n_orb (int): Number of spatial orbitals (Bitstring length).
        raw_n_vectors (Tuple[np.ndarray, ...]): Tuple of raw n-vectors (occupancy profiles) for each cluster.
        samples_per_cluster (Tuple[int, ...]): Target number of samples to extract from each cluster.
        correct_samples (Tuple[np.ndarray, ...]): Tuple containing arrays of valid existing samples per cluster.
        correct_probs (Tuple[np.ndarray, ...]): Tuple containing probability arrays for existing samples.
        seed (int, optional): Random seed for reproducibility. Defaults to None.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
            - unique_batch (np.ndarray): De-duplicated array of basis bitstrings. Shape: (N_unique, n_orb).
            - unique_probs (np.ndarray): Aggregated probability weights for each unique basis. Shape: (N_unique,).
            - membership_matrix (np.ndarray): Boolean matrix indicating cluster lineage.
              Shape: (N_unique, n_clusters). Element [i, k] is True if basis 'i' originated from cluster 'k'.

    Raises:
        ValueError: If no cluster has a positive target count.
    """

    # 1. Setup & Initialization
    all_sampled_strs = []
    all_sampled_probs = []
    all_sampled_cluster_ids = []  # cluster origin of each sample
    total_target_count = sum(samples_per_cluster)

    rng = np.random.default_rng(seed)
    n_clusters = len(raw_n_vectors)

    # 2. Cluster-wise Sampling & Generation
    for k in range(n_clusters):
        target_count = int(samples_per_cluster[k])
        if target_count <= 0:
            # Zero-weight cluster: nothing to draw. Calling rng.choice here with an empty
            # probability vector would raise "Probabilities do not sum to 1".
            continue

        real_samples = correct_samples[k]
        real_probs = correct_probs[k]
        num_real = len(real_samples)

        # Global weight of this cluster relative to the total target count.
        cluster_global_weight = target_count / total_target_count

        # 2-A. Enough real (measured) samples: draw without replacement by probability.
        if num_real >= target_count:
            p_choice = real_probs / np.sum(real_probs)
            indices = rng.choice(num_real, size=target_count, replace=False, p=p_choice)
            cluster_k_strs = real_samples[indices]
            cluster_k_probs = real_probs[indices]

        # 2-B. Not enough real data: keep all real samples and synthesize the deficit.
        else:
            strs_parts = []
            probs_parts = []
            if num_real > 0:
                strs_parts.append(real_samples)
                probs_parts.append(real_probs)

            deficit = target_count - num_real
            prob_dist = _orbital_sampling_distribution(raw_n_vectors[k], n_elec)

            gen_samples = []
            gen_probs = []
            # Bool-normalized keys so real and synthetic blocks dedupe against each other.
            seen = {np.asarray(s, dtype=bool).tobytes() for s in real_samples}

            attempts = 0
            while len(gen_samples) < deficit and attempts < deficit * 20:
                attempts += 1
                occ_idx = rng.choice(n_orb, size=n_elec, replace=False, p=prob_dist)
                sample = np.zeros(n_orb, dtype=bool)
                sample[occ_idx] = True

                key = sample.tobytes()
                if key in seen:
                    continue
                seen.add(key)
                gen_samples.append(sample)
                # Importance score = product of the selection probabilities of the occupied
                # orbitals. It is not on the same scale as real probabilities (typically much
                # smaller), so after normalization real samples dominate and synthetic samples
                # mainly compete among themselves.
                gen_probs.append(np.prod(prob_dist[occ_idx]))

            if gen_samples:
                strs_parts.append(np.array(gen_samples))
                probs_parts.append(np.array(gen_probs))

            cluster_k_strs = np.vstack(strs_parts)
            cluster_k_probs = np.concatenate(probs_parts)

        # 2-C. Normalize within the cluster, then apply the cluster's global weight.
        all_sampled_strs.append(cluster_k_strs)
        all_sampled_probs.append((cluster_k_probs / np.sum(cluster_k_probs)) * cluster_global_weight)
        all_sampled_cluster_ids.append(np.full(len(cluster_k_strs), k))

    if not all_sampled_strs:
        raise ValueError("samples_per_cluster must request at least one sample from some cluster")

    # 3. Aggregation
    merged_strs = np.vstack(all_sampled_strs)
    merged_probs = np.concatenate(all_sampled_probs)
    merged_ids = np.concatenate(all_sampled_cluster_ids)

    # 4. De-duplication & probability aggregation.
    # The inverse is flattened to 1-D because its shape with `axis` differs across NumPy versions.
    unique_batch, inverse_indices = np.unique(merged_strs, axis=0, return_inverse=True)
    inverse_indices = np.asarray(inverse_indices).reshape(-1)

    unique_probs = np.zeros(len(unique_batch))
    np.add.at(unique_probs, inverse_indices, merged_probs)

    # 5. Membership matrix: mark every cluster that contributed to each unique basis element.
    membership_matrix = np.zeros((len(unique_batch), n_clusters), dtype=bool)
    membership_matrix[inverse_indices, merged_ids] = True

    return unique_batch, unique_probs, membership_matrix



def make_initial_batches(
    n_batch: int,
    n_elec: int, 
    n_orb: int,
    raw_n_vectors: tuple[np.ndarray, ...],
    samples_per_cluster: tuple[int, ...],
    correct_samples: tuple[np.ndarray, ...],
    correct_probs: tuple[np.ndarray, ...],
    base_seed=None
    ) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Build ``n_batch`` initial batches, each produced by one call to `initial_subsample_cluster`.

    Batch ``i`` is generated with ``seed=base_seed + i`` when ``base_seed`` is given and
    with ``seed=None`` otherwise. Batches are generated in index order.

    Args:
        n_batch (int): Number of batches to generate.
        n_elec (int): Number of electrons per spin.
        n_orb (int): Number of spatial orbitals.
        raw_n_vectors (Tuple[np.ndarray, ...]): Raw n-vectors, one per cluster.
        samples_per_cluster (Tuple[int, ...]): Target sample count per cluster.
        correct_samples (Tuple[np.ndarray, ...]): Existing valid samples per cluster.
        correct_probs (Tuple[np.ndarray, ...]): Probabilities of the existing samples.
        base_seed (Optional[int]): Base random seed, offset by the batch index for each batch.

    Returns:
        List[Tuple[np.ndarray, np.ndarray, np.ndarray]]: One ``(batch, probs, membership)``
        triple per batch, as returned by `initial_subsample_cluster`.
    """
    # Per-batch seeds: each batch gets its own seed while the whole run stays reproducible.
    seeds = [None if base_seed is None else base_seed + idx for idx in range(n_batch)]

    # Generate the batches lazily, in the same order as the seeds.
    generated = (
        initial_subsample_cluster(
            n_elec,
            n_orb,
            raw_n_vectors,
            samples_per_cluster,
            correct_samples,
            correct_probs,
            seed=seed,
        )
        for seed in seeds
    )

    # Unpack each result into a (basis, weights, membership) triple.
    return [(basis, weights, members) for basis, weights, members in generated]



def solve_sci_cluster(
    ci_strings: Tuple[np.ndarray, np.ndarray], 
    one_body_tensor: np.ndarray, 
    two_body_tensor: np.ndarray, 
    norb: int, 
    nelec: Tuple[int, int], 
    nuclear_repulsion_energy: float,
    membership_matrix: np.ndarray,
    spin_sq: Optional[float] = 0.0
    ) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray, float, np.ndarray]:
    """
    Diagonalize the Hamiltonian within the fixed subspace spanned by ``ci_strings``.

    Returns the wavefunction coefficients and the total energy, passing the basis strings
    and the cluster membership metadata through unchanged.

    Args:
        ci_strings (Tuple[np.ndarray, np.ndarray]): Tuple of (alpha_strs, beta_strs).
                                                    Each element is a list/array of determinants.
        one_body_tensor (np.ndarray): One-electron integral (core Hamiltonian) matrix.
        two_body_tensor (np.ndarray): Two-electron integral (ERI) 4D tensor.
        norb (int): Number of spatial orbitals in the active space.
        nelec (Tuple[int, int]): Tuple of (n_alpha, n_beta) electrons.
        nuclear_repulsion_energy (float): Constant energy term added to the electronic energy
                                          (nuclear repulsion and frozen core).
        membership_matrix (np.ndarray): Boolean matrix of shape (N_basis, n_clusters).
                                        Element [i, k] is True if basis 'i' originated from cluster 'k'.
                                        Returned unchanged.
        spin_sq (Optional[float]): Target S^2 value passed to ``fci.addons.fix_spin_``
                                   (default 0.0, i.e. singlet). Pass ``None`` to run the
                                   solver without a spin constraint.

    Returns:
        Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray, float, np.ndarray]:
            - ci_strings: The input basis strings tuple.
            - coefficients (np.ndarray): Wavefunction coefficient matrix (N_alpha, N_beta) obtained from diagonalization.
            - total_energy (float): Electronic energy (from the RDMs) plus ``nuclear_repulsion_energy``.
            - membership_matrix: The input membership matrix.
    """
    # 1. Selected CI solver setup, with an optional spin constraint.
    myci = fci.selected_ci.SelectedCI()
    if spin_sq is not None:
        myci = fci.addons.fix_spin_(myci, ss=spin_sq)

    # 2. Diagonalize in the fixed subspace; kernel_fixed_space returns (energy, coefficients).
    #    The returned eigenvalue is discarded; the energy is recomputed from the RDMs below.
    _, coef = fci.selected_ci.kernel_fixed_space(
        myci,
        one_body_tensor,
        two_body_tensor,
        norb,
        nelec,
        ci_strs=ci_strings,
    )

    # 3. Electronic energy from the 1- and 2-RDMs contracted with the integrals.
    dm1 = myci.make_rdm1(coef, norb, nelec)
    dm2 = myci.make_rdm2(coef, norb, nelec)
    e_elec = np.einsum("pr,pr->", dm1, one_body_tensor) + 0.5 * np.einsum(
        "prqs,prqs->", dm2, two_body_tensor
    )

    # 4. Total energy = electronic part + constant (nuclear repulsion / frozen core).
    total_energy = e_elec + nuclear_repulsion_energy

    return (ci_strings, coef, total_energy, membership_matrix)



def cal_n_vecs(
    best_result: Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray, float, np.ndarray], 
    old_n_vectors: Tuple[np.ndarray, ...], 
    n_elec: int, 
    n_orb: int, 
    n_clusters: int
    ) -> Tuple[np.ndarray, ...]:
    """
    Calculate new n-vectors (occupancy profiles) for each cluster from the optimized wavefunction.

    For each cluster ``k``, the occupation bitstrings (decoded from the alpha strings,
    most-significant bit first) of the basis states flagged in ``membership[:, k]`` are
    averaged. The weights are the state's alpha contribution (|c|^2 summed over beta
    indices) plus its beta contribution (|c|^2 summed over alpha indices). The result is
    rescaled so its entries sum to ``n_elec``.

    A cluster keeps its previous n-vector when it has no member states, or when its
    members carry zero total weight. The second case would otherwise be a 0/0 division
    producing NaNs.

    Args:
        best_result (Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray, float, np.ndarray]): 
            A tuple containing:
            - ci_strings: Tuple of (alpha_strs, beta_strs) integer arrays.
            - coefficients: Wavefunction coefficient matrix of shape (N_alpha, N_beta).
            - energy: Total energy of the state (unused here).
            - membership_matrix: Boolean matrix of shape (N_basis, n_clusters).
        old_n_vectors (Tuple[np.ndarray, ...]): Tuple of n-vectors from the previous iteration.
        n_elec (int): Value each new n-vector is normalized to sum to. The vectors are built
            from single-spin strings over ``n_orb`` orbitals, so this is normally the number
            of electrons per spin.
        n_orb (int): Number of spatial orbitals (bits per string).
        n_clusters (int): Number of clusters.

    Returns:
        Tuple[np.ndarray, ...]: One n-vector per cluster. Recomputed vectors are float arrays
        of shape (n_orb,); carried-over vectors are the corresponding ``old_n_vectors`` entries.
    """
    # 1. Data extraction (the energy is not needed).
    ci_strs, coef, _, membership = best_result
    alpha_strs = ci_strs[0]

    # 2. Per-string probability weights from |c|^2.
    probs_matrix = np.abs(coef) ** 2
    alpha_contribution = np.sum(probs_matrix, axis=1)  # weight of each alpha string
    beta_contribution = np.sum(probs_matrix, axis=0)   # weight of each beta string

    # 3. Decode alpha integers into an (N_alpha, n_orb) 0/1 matrix, most-significant bit first.
    #    Python ints are used for the shifts so strings wider than 63 bits still work; the
    #    reshape keeps the 2-D shape even when there are no strings.
    bit_mat = np.array(
        [[(int(s) >> p) & 1 for p in range(n_orb - 1, -1, -1)] for s in alpha_strs],
        dtype=float,
    ).reshape(len(alpha_strs), n_orb)

    new_n_vecs = []

    # 4. Update each cluster.
    for k in range(n_clusters):
        mask = membership[:, k]

        # No member states: carry over the previous n-vector.
        if not np.any(mask):
            new_n_vecs.append(old_n_vectors[k])
            continue

        # NOTE(review): adding alpha- and beta-indexed weights element-wise (and decoding only
        # the alpha strings) assumes the alpha and beta string lists are identical and aligned
        # with the membership rows — confirm with the caller.
        string_weights = alpha_contribution[mask] + beta_contribution[mask]

        # Weighted sum of the member bitstrings.
        n_vec_k = np.dot(string_weights, bit_mat[mask])

        # 5. Normalization. A non-positive (or NaN) total means the members carry no weight;
        #    treat that like "no contribution" instead of dividing by zero.
        total = np.sum(n_vec_k)
        if not total > 0:
            new_n_vecs.append(old_n_vectors[k])
            continue

        new_n_vecs.append((n_vec_k / total) * n_elec)

    return tuple(new_n_vecs)




def _bipartite_bitstring_correcting_cluster(
    bitstring: np.ndarray, 
    n_vec: np.ndarray, 
    n_elec: int, 
    n_orb: int, 
    rng: np.random.Generator
    ) -> np.ndarray:
    """
    Correct the particle number (Hamming weight) of a single spin block (alpha or beta),
    guided by the cluster's n-vector occupancy profile.

    Each orbital gets a flip weight:
    - occupied orbitals use `_p_flip_1_to_0`;
    - empty orbitals use `_p_flip_0_to_1`.
    Both are evaluated at the naive filling ratio ``n_elec / n_orb`` and the orbital's
    ``n_vec`` entry.

    If the block has too many electrons, the excess number of occupied orbitals is emptied.
    If it has too few, the missing number of empty orbitals is filled. The orbitals to flip
    are drawn without replacement, with probabilities proportional to their flip weights.

    If fewer candidate orbitals have a positive, finite weight than the number of flips
    required, the remaining candidates get a small fallback weight. Without it,
    ``rng.choice`` would raise.

    Args:
        bitstring (np.ndarray): 1D array of length ``n_orb`` representing the spin block;
            interpreted as booleans.
        n_vec (np.ndarray): Occupancy profile vector for the specific cluster (length ``n_orb``).
        n_elec (int): Target number of electrons for this spin block.
        n_orb (int): Number of orbitals.
        rng (np.random.Generator): Random number generator instance for reproducibility.

    Returns:
        np.ndarray: A new boolean 1D array with ``n_elec`` set bits (for 0 <= n_elec <= n_orb).
        The input array is not modified.
    """
    # 1. Initialization. np.array(..., dtype=bool) copies the input, so the caller's array is
    #    untouched. It also guarantees the boolean masks below: an integer 0/1 array would
    #    be interpreted as fancy indices instead.
    bit_array = np.array(bitstring, dtype=bool)
    ratio_exp = n_elec / n_orb  # naive expected occupancy per orbital

    # 2. Flip weight for every orbital, depending on its current state.
    probs_vec = np.zeros(n_orb)
    for i in range(n_orb):
        if bit_array[i]:
            # Probability of flipping an occupied orbital 1 -> 0 (remove an electron).
            probs_vec[i] = _p_flip_1_to_0(ratio_exp, n_vec[i], 0.01)
        else:
            # Probability of flipping an empty orbital 0 -> 1 (add an electron).
            probs_vec[i] = _p_flip_0_to_1(ratio_exp, n_vec[i], 0.01)

    # Safety guard: sampling weights must be non-negative.
    probs_vec = np.absolute(probs_vec)

    # 3. Correction direction: positive = excess electrons, negative = missing electrons.
    n_diff = np.sum(bit_array) - n_elec
    if n_diff > 0:
        candidate_mask = bit_array                  # empty some occupied orbitals
        new_value = False
    elif n_diff < 0:
        candidate_mask = np.logical_not(bit_array)  # fill some empty orbitals
        new_value = True
    else:
        return bit_array

    candidates = np.flatnonzero(candidate_mask)
    weights = probs_vec[candidate_mask]
    n_flips = int(round(abs(n_diff)))

    # 4. Make the draw feasible. Non-finite weights cannot be sampled. Sampling without
    #    replacement also needs at least n_flips non-zero weights. When that is not met,
    #    zero-weight candidates get a floor far below the smallest positive weight (or a
    #    uniform weight if none is positive). This keeps the draw biased towards the
    #    orbitals the n-vector favours.
    weights = np.where(np.isfinite(weights), weights, 0.0)
    positive = weights > 0
    if np.count_nonzero(positive) < n_flips:
        floor = weights[positive].min() * 1e-3 if positive.any() else 1.0
        weights = np.where(positive, weights, floor)

    p_choice = weights / np.sum(weights)

    # 5. Apply the corrections.
    indices_to_flip = rng.choice(candidates, size=n_flips, replace=False, p=p_choice)
    bit_array[indices_to_flip] = new_value

    return bit_array



def refine_a_cluster(
    cluster_bits: np.ndarray, 
    cluster_probs: np.ndarray, 
    n_vec: np.ndarray, 
    n_elec: int, 
    n_orb: int, 
    rng: np.random.Generator
    ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Refine the noisy bitstrings of one cluster: correct particle-number violations guided
    by the cluster's n-vector, then deduplicate.

    Each bitstring is corrected with `_bipartite_bitstring_correcting_cluster`, in order,
    using the shared ``rng``. Identical corrected strings are merged, and their
    probabilities are summed.

    Args:
        cluster_bits (np.ndarray): 2D array of noisy bitstrings belonging to the cluster (Shape: [N_k, n_orb]).
        cluster_probs (np.ndarray): 1D array of probabilities/frequencies associated with the bitstrings (Shape: [N_k,]).
        n_vec (np.ndarray): The guide occupancy vector for this cluster (Shape: [n_orb,], Sum should approx n_elec).
        n_elec (int): Target number of electrons (Hamming weight).
        n_orb (int): Number of spatial orbitals.
        rng (np.random.Generator): Random number generator instance.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - refined_unique_bits (np.ndarray): Unique corrected bitstrings (rows sorted by np.unique).
              Shape (0, n_orb) when the cluster is empty.
            - refined_unique_probs (np.ndarray): Summed probability of each unique bitstring.
    """
    # An empty cluster has nothing to refine. Return correctly shaped empty arrays rather than
    # the 1-D (0,) array that np.array([]) would produce.
    if len(cluster_bits) == 0:
        return np.zeros((0, n_orb), dtype=bool), np.zeros(0)

    # 1. Correct every bitstring, in order, to the target particle number.
    refined_matrix = np.array([
        _bipartite_bitstring_correcting_cluster(bits, n_vec, n_elec, n_orb, rng)
        for bits in cluster_bits
    ])

    # 2. Pooling: several noisy strings can collapse onto the same corrected string, so
    #    deduplicate and sum their probabilities.
    unique_bits, inverse_indices = np.unique(refined_matrix, axis=0, return_inverse=True)
    # Flatten defensively: the shape of the return_inverse output has varied across NumPy 2.x releases.
    inverse_indices = np.asarray(inverse_indices).reshape(-1)

    unique_probs = np.zeros(len(unique_bits))
    np.add.at(unique_probs, inverse_indices, cluster_probs)

    return unique_bits, unique_probs



def subsample_cluster(
    refined_clusters: list[Tuple[np.ndarray, np.ndarray]], 
    samples_per_cluster: Tuple[int, ...], 
    n_orb: int, 
    rng: np.random.Generator
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample from the refined cluster pools and build a unified basis batch with a
    membership matrix for n-vector updates.

    From cluster ``k``, up to ``samples_per_cluster[k]`` distinct states are drawn without
    replacement, with probabilities ``p_k`` (the cluster's probabilities normalized to sum
    to 1). When the pool has fewer states with non-zero probability than requested, all
    of them are drawn; the sampler would otherwise raise. Each sampled state's weight is
    ``p_k * samples_per_cluster[k] / sum(samples_per_cluster)``. Identical states from
    different draws are merged and their weights summed.

    Args:
        refined_clusters (List[Tuple[np.ndarray, np.ndarray]]): List of tuples, where each tuple contains 
            (refined_bitstrings, probabilities) for a specific cluster.
        samples_per_cluster (Tuple[int, ...]): Target number of samples to extract from each cluster.
        n_orb (int): Number of spatial orbitals (not used by the computation).
        rng (np.random.Generator): Random number generator instance for reproducibility.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
            - unique_batch (np.ndarray): Deduplicated unified basis matrix of shape (N_unique, n_orb).
            - unique_probs (np.ndarray): Aggregated probability weights for each unique basis.
            - membership_matrix (np.ndarray): Boolean matrix indicating cluster lineage. 
              Shape: (N_unique, n_clusters). Element [i, k] is True if basis 'i' originated from cluster 'k'.
    """
    # 1. Initialization
    all_sampled_strs = []
    all_sampled_probs = []
    all_sampled_cluster_ids = []  # cluster origin of each sample

    total_target_count = sum(samples_per_cluster)
    n_clusters = len(refined_clusters)

    # 2. Cluster-wise sampling
    for k, (ref_bits, ref_probs) in enumerate(refined_clusters):
        target_count = int(samples_per_cluster[k])

        # The cluster's share of the batch is based on the requested sample count.
        cluster_global_weight = target_count / total_target_count

        # Normalize intra-cluster probabilities to sum to 1.0 for sampling.
        p_intra = ref_probs / np.sum(ref_probs)

        # Sampling without replacement cannot return more distinct states than have
        # non-zero probability, so cap the draw size.
        draw_count = min(target_count, int(np.count_nonzero(p_intra)))
        indices = rng.choice(len(ref_bits), size=draw_count, replace=False, p=p_intra)

        # Scale by the cluster's global weight.
        scaled_probs = p_intra[indices] * cluster_global_weight

        all_sampled_strs.append(ref_bits[indices])               # selected bitstrings
        all_sampled_probs.append(scaled_probs)                   # scaled probabilities
        all_sampled_cluster_ids.append(np.full(draw_count, k))   # originating cluster

    # 3. Aggregation
    merged_strs = np.vstack(all_sampled_strs)
    merged_probs = np.concatenate(all_sampled_probs)
    merged_ids = np.concatenate(all_sampled_cluster_ids)

    # 4. Deduplication & probability aggregation
    unique_batch, inverse_indices = np.unique(merged_strs, axis=0, return_inverse=True)
    # Flatten defensively: the shape of the return_inverse output has varied across NumPy 2.x releases.
    inverse_indices = np.asarray(inverse_indices).reshape(-1)

    unique_probs = np.zeros(len(unique_batch))
    np.add.at(unique_probs, inverse_indices, merged_probs)

    # 5. Membership matrix: flag every (unique state, originating cluster) pair.
    membership_matrix = np.zeros((len(unique_batch), n_clusters), dtype=bool)
    membership_matrix[inverse_indices, merged_ids] = True

    return unique_batch, unique_probs, membership_matrix




def make_batches(
    n_batch: int,
    best_result: Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray, float, np.ndarray],
    refined_clusters: list[Tuple[np.ndarray, np.ndarray]],
    samples_per_cluster: Tuple[int, ...],
    n_orb: int,
    threshold: float,
    max_dim: Optional[int],
    base_seed: Optional[int] = None
    ) -> list[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Construct diagonalization batches by combining significant basis states from the
    previous iteration (carry-over) with freshly sampled refined states.

    Carry-over: a basis string is kept when any coefficient in its row or column of the
    previous coefficient matrix exceeds ``threshold`` in absolute value. Its weight is its
    row sum of |c|^2. If more than ``max_dim`` strings qualify, only the ``max_dim`` with
    the largest (row + column) sum of |c| are kept.

    Each batch ``b`` then merges the carry-over states with a new `subsample_cluster` draw
    (seed ``base_seed + b`` if given). The merged states are deduplicated, with weights
    summed and cluster memberships OR-ed. They are sorted by weight in descending order
    and truncated to ``max_dim``.

    Args:
        n_batch (int): Total number of batches to generate.
        best_result (Tuple): Tuple containing (ci_strings, coef, energy, membership_matrix) from the previous step.
        refined_clusters (List): List of (bits, probs) tuples for each cluster.
        samples_per_cluster (Tuple[int, ...]): Target number of samples to extract per cluster.
        n_orb (int): Number of spatial orbitals.
        threshold (float): Absolute coefficient threshold to select carry-over basis states.
        max_dim (Optional[int]): Maximum number of unique basis states per batch (None for no limit).
        base_seed (Optional[int]): Base random seed.

    Returns:
        List[Tuple[np.ndarray, np.ndarray, np.ndarray]]: A list of tuples, where each tuple contains:
            - unique_batch (np.ndarray): The unified, deduplicated basis set.
            - unique_probs (np.ndarray): Aggregated probability weights.
            - unique_membership (np.ndarray): Merged cluster membership matrix.
    """

    # =========================================================================
    # SECTION 1: Extract Carry-over Basis & Weights
    # =========================================================================
    ci_strs, coef, _, prev_membership = best_result
    alpha_ints = np.asarray(ci_strs[0])  # asarray so boolean masking also works for list input
    abs_coef = np.abs(coef)

    # 1-A. Important states: any coefficient in the row (alpha) or column (beta) above threshold.
    # NOTE(review): OR-ing row- and column-wise masks assumes the alpha and beta string lists
    # are identical and aligned with prev_membership rows — confirm with the caller.
    important_alpha = np.any(abs_coef > threshold, axis=1)
    important_beta = np.any(abs_coef > threshold, axis=0)
    important_mask = important_alpha | important_beta

    carry_over_ints = alpha_ints[important_mask]
    carry_over_membership = prev_membership[important_mask]

    # 1-B. Carry-over weights (row sums of |c|^2) and truncation scores (row + column sums of |c|).
    carry_over_weights = np.sum(abs_coef ** 2, axis=1)[important_mask]
    importance_scores = np.sum(abs_coef, axis=1)[important_mask] + np.sum(abs_coef, axis=0)[important_mask]

    # 1-C. Apply the max-dimension constraint to the carry-over set. Slicing from an explicit
    #      start index (instead of [-max_dim:]) also handles max_dim == 0 correctly.
    if max_dim is not None and len(carry_over_ints) > max_dim:
        order = np.argsort(importance_scores)
        top_indices = order[len(order) - max_dim:]
        carry_over_ints = carry_over_ints[top_indices]
        carry_over_membership = carry_over_membership[top_indices]
        carry_over_weights = carry_over_weights[top_indices]

    # 1-D. Integers -> boolean bit rows, most-significant bit first. The reshape keeps the
    #      (0, n_orb) shape when nothing is carried over, so the later vstack still works.
    def ints_to_bits(ints, n_orb):
        return np.array(
            [[(i >> p) & 1 for p in range(n_orb - 1, -1, -1)] for i in ints], dtype=bool
        ).reshape(len(ints), n_orb)

    carry_over_bits = ints_to_bits(carry_over_ints, n_orb)

    all_batches = []
    n_clusters = len(refined_clusters)

    # =========================================================================
    # SECTION 2: Batch Generation Loop
    # =========================================================================
    for b_idx in range(n_batch):
        rng = np.random.default_rng(base_seed + b_idx if base_seed is not None else None)

        # 2-A. Fresh samples from the refined clusters.
        new_bits, new_probs, new_membership = subsample_cluster(
            refined_clusters, samples_per_cluster, n_orb, rng
        )

        # 2-B. Merge the carry-over basis with the new samples.
        combined_bits = np.vstack([carry_over_bits, new_bits])
        combined_probs = np.concatenate([carry_over_weights, new_probs])
        combined_membership = np.vstack([carry_over_membership, new_membership])

        # 2-C. Deduplication & probability aggregation.
        unique_batch, inv_idx = np.unique(combined_bits, axis=0, return_inverse=True)
        # Flatten defensively: the shape of the return_inverse output has varied across NumPy 2.x releases.
        inv_idx = np.asarray(inv_idx).reshape(-1)

        unique_probs = np.zeros(len(unique_batch))
        np.add.at(unique_probs, inv_idx, combined_probs)

        # 2-D. Membership aggregation: OR together the flags of all rows mapping to the same state.
        unique_membership = np.zeros((len(unique_batch), n_clusters), dtype=bool)
        np.logical_or.at(unique_membership, inv_idx, combined_membership)

        # 2-E. Sort by weight (descending) so truncation keeps the heaviest states.
        sorted_indices = np.argsort(unique_probs)[::-1]
        unique_batch = unique_batch[sorted_indices]
        unique_probs = unique_probs[sorted_indices]
        unique_membership = unique_membership[sorted_indices]

        # Apply the strict max_dim constraint to the final batch.
        if max_dim is not None and len(unique_batch) > max_dim:
            unique_batch = unique_batch[:max_dim]
            unique_probs = unique_probs[:max_dim]
            unique_membership = unique_membership[:max_dim]

        # 2-F. Store the batch.
        all_batches.append((unique_batch, unique_probs, unique_membership))

    return all_batches