Source code for aiida_trains_pot.data.pesdata

from aiida.orm import Data
import tempfile
import os
import numpy as np
from ase.calculators.singlepoint import SinglePointCalculator
from ase import Atoms
import io
from contextlib import redirect_stdout
from ase.io import write
import warnings
import h5py
import re
from ase.data import atomic_numbers

def convert_stress(stress):
    """Convert a flattened stress (6-component Voigt or 9-component) to a 3x3 matrix."""
    if len(np.shape(stress)) == 1:
        if len(stress) == 6:
            # Voigt order: xx, yy, zz, yz, xz, xy
            stress = np.array([[stress[0], stress[5], stress[4]],
                               [stress[5], stress[1], stress[3]],
                               [stress[4], stress[3], stress[2]]])
        else:
            stress = np.array([[stress[0], stress[1], stress[2]],
                               [stress[3], stress[4], stress[5]],
                               [stress[6], stress[7], stress[8]]])
    return stress
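
# Example (illustrative, not part of the module): convert_stress accepts either a
# 6-component Voigt vector (xx, yy, zz, yz, xz, xy, the order used by
# ase.Atoms.get_stress) or a flattened 9-component array, and returns the
# symmetric 3x3 stress matrix:
#
#     >>> convert_stress([1.0, 2.0, 3.0, 0.4, 0.5, 0.6])
#     array([[1. , 0.6, 0.5],
#            [0.6, 2. , 0.4],
#            [0.5, 0.4, 3. ]])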


class PESData(Data):

    @property
    def _list_key(self):
        """Generate a unique filename for the dataset based on the node's UUID."""
        return f"psedata_{self.uuid}.npz"  # Unique filename with the node's UUID

    def __init__(self, data=None, save_labels=True, **kwargs):
        """
        Initialize a PESData instance.

        :param data: Optional list of data to initialize the PESData node.
        :param save_labels: Whether to store DFT labels (energy, forces, stress) if present.
        :param kwargs: Additional keyword arguments passed to the parent Data class.
        """
        super().__init__(**kwargs)  # Initialize the parent class
        if data:
            if isinstance(data[0], Atoms):
                self.set_ase(data, save_labels=save_labels)
            else:
                self.set_list(data)

    def __iter__(self):
        """Return an iterator over the dataset."""
        self._index = 0
        self._max_index = self.base.attributes.get('dataset_size', 0)
        if self._max_index == 0:
            # Count the number of items in the HDF5 file
            try:
                with self.base.repository.open(self._list_key, 'rb') as hdf_file:
                    with h5py.File(hdf_file, 'r') as hdf:
                        self._max_index = len(hdf)
            except (FileNotFoundError, KeyError):
                self._max_index = 0
        return self

    def __next__(self):
        """Return the next item from the dataset."""
        if self._index < self._max_index:
            result = self.get_item(self._index)
            self._index += 1
            return result
        raise StopIteration

    def __add__(self, other):
        """Combine two PESData objects into a new node with the + operator."""
        if not isinstance(other, PESData):
            raise TypeError(f"Cannot add {type(other)} to PESData")
        return self.__iadd__(other)

    def __iadd__(self, other):
        """Support the += operation for combining two PESData objects by creating a new node."""
        if not isinstance(other, PESData):
            raise TypeError(f"Cannot add {type(other)} to PESData")
        return PESData(data=self.get_list() + other.get_list())

    def __len__(self):
        """Return the number of configurations in the dataset."""
        try:
            if self.base.attributes.get('dataset_size'):
                return self.base.attributes.get('dataset_size')
            return len(self.get_list())
        except Exception:
            return 0

    def _extract_config_from_group(self, group):
        """
        Helper method to extract configuration data from an HDF5 group.

        :param group: HDF5 group object containing configuration data
        :return: Dictionary with the configuration data
        """
        config = {}
        # Extract datasets, converting them to lists or native Python types
        for key, value in group.items():
            if isinstance(value, h5py.Dataset):
                config[key] = value[()].tolist() if hasattr(value[()], 'tolist') else value[()]
            else:
                config[key] = value[()]
        # Add attributes to the config
        for attr_key, attr_value in group.attrs.items():
            config[attr_key] = attr_value
        # Ensure symbols are decoded properly
        if 'symbols' in config:
            config['symbols'] = [str(symbol, 'utf-8') if isinstance(symbol, bytes) else str(symbol)
                                 for symbol in config['symbols']]
        return config

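    # Usage sketch (illustrative, not part of the module): a PESData node behaves
    # like a sequence of configuration dictionaries, so nodes can be built from
    # ASE Atoms objects or plain dicts, combined, and iterated. The variable names
    # below are placeholders:
    #
    #     dataset = PESData(data=list_of_ase_atoms)   # or a list of config dicts
    #     combined = dataset + other_dataset          # new, unstored PESData node
    #     for config in combined:
    #         print(config['symbols'], config.get('dft_energy'))
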
    def get_atomic_species(self):
        """Return the list of atomic species in the dataset."""
        try:
            return self.base.attributes.get('atomic_species', [])
        except Exception as e:
            print(f"An error occurred while retrieving atomic species: {e}")
            return []

    def get_atomic_numbers(self):
        """Return the list of atomic numbers in the dataset."""
        atomic_species = self.get_atomic_species()
        atomic_nums = [atomic_numbers[symbol] for symbol in atomic_species if symbol in atomic_numbers]
        return atomic_nums

    def get_item(self, index):
        """Return a specific item from the dataset by index."""
        try:
            with self.base.repository.open(self._list_key, 'rb') as hdf_file:
                with h5py.File(hdf_file, 'r') as hdf:
                    group_key = f"item_{index}"
                    if group_key not in hdf:
                        raise IndexError(f"Index {index} out of range")
                    return self._extract_config_from_group(hdf[group_key])
        except FileNotFoundError as e:
            print(f"File '{self._list_key}' not found: {e}")
            raise IndexError(f"Index {index} out of range") from e
        except IndexError:
            # Propagate out-of-range errors without the generic message below
            raise
        except Exception as e:
            print(f"An error occurred while reading '{self._list_key}': {e}")
            raise

    def iter_items(self):
        """Generator to iterate through items without loading all of them into memory."""
        try:
            with self.base.repository.open(self._list_key, 'rb') as hdf_file:
                with h5py.File(hdf_file, 'r') as hdf:
                    for group_key in sorted(hdf.keys(), key=lambda k: int(k.split('_')[1])):
                        yield self._extract_config_from_group(hdf[group_key])
        except FileNotFoundError as e:
            print(f"File '{self._list_key}' not found: {e}")
        except Exception as e:
            print(f"An error occurred while reading '{self._list_key}': {e}")

    def get_list(self, max_items=None, warn_threshold=1000):
        """
        Return the contents of this node as a list.

        :param max_items: Optional limit to the number of items to load
        :param warn_threshold: Show warning if loading more than this many items
        """
        n_items = self.base.attributes.get('dataset_size', 0)
        if n_items == 0:
            # Count the number of items in the HDF5 file
            try:
                with self.base.repository.open(self._list_key, 'rb') as hdf_file:
                    with h5py.File(hdf_file, 'r') as hdf:
                        n_items = len(hdf)
            except (FileNotFoundError, KeyError):
                n_items = 0

        if max_items is None:
            max_items = n_items
        else:
            max_items = min(max_items, n_items)

        if max_items > warn_threshold:
            warnings.warn(f"Loading {max_items} items into memory. This may consume a significant amount of RAM. "
                          f"Consider using iter_items() for memory-efficient iteration.", UserWarning)

        # Use the iterator to build the list, which is more consistent and reuses code
        data = []
        for i, config in enumerate(self.iter_items()):
            if i >= max_items:
                break
            data.append(config)
        return data

    def set_ase(self, data, save_labels=True):
        """
        Set the contents of this node by saving a list of ASE Atoms objects as an HDF5 file.

        :param data: A list of ASE Atoms objects to save.
        :param save_labels: Whether to store DFT labels (energy, forces, stress) from attached calculators.
        """
        from ase.calculators.singlepoint import SinglePointDFTCalculator as dft_calc
        from ase.calculators.singlepoint import SinglePointCalculator as single_calc

        # Ensure data is a list of Atoms objects
        if not isinstance(data, list):
            raise TypeError("Input data must be a list of ase.atoms.Atoms.")
        for item in data:
            if not isinstance(item, Atoms):
                raise TypeError("Input data must be a list of ase.atoms.Atoms.")

        num_labelled_frames = 0
        num_unlabelled_frames = 0
        symb = set()
        save_data = []
        for atm in data:
            info = atm.info
            if save_labels and isinstance(atm.calc, (dft_calc, single_calc)):
                num_labelled_frames += 1
                save_data.append({'cell': atm.cell,
                                  'symbols': atm.get_chemical_symbols(),
                                  'positions': atm.get_positions(),
                                  'pbc': atm.pbc,
                                  'dft_energy': atm.calc.results['energy'],
                                  'dft_forces': atm.calc.results['forces']})
                symb = symb.union(set(save_data[-1]['symbols']))
                for key, value in info.items():
                    if 'energy' not in key:
                        save_data[-1][key] = value
                try:
                    stress = atm.get_stress(voigt=False)
                    save_data[-1]['dft_stress'] = stress
                except Exception:
                    pass  # Frame has no stress; skip it
            else:
                num_unlabelled_frames += 1
                save_data.append({'cell': atm.cell,
                                  'symbols': atm.get_chemical_symbols(),
                                  'positions': atm.get_positions(),
                                  'pbc': atm.pbc})
                symb = symb.union(set(save_data[-1]['symbols']))
                for key, value in info.items():
                    if 'energy' not in key:
                        save_data[-1][key] = value

        try:
            # Create a temporary directory to save the file
            with tempfile.TemporaryDirectory() as temp_dir:
                dataset_temp_file = os.path.join(temp_dir, f"{self._list_key}.h5")
                # Save the data using h5py
                with h5py.File(dataset_temp_file, 'w') as hdf:
                    for idx, item in enumerate(save_data):
                        group = hdf.create_group(f"item_{idx}")
                        for key, value in item.items():
                            if isinstance(value, (list, np.ndarray)):
                                group.create_dataset(key, data=value)
                            else:
                                group.attrs[key] = value
                # Store the file in the AiiDA repository
                self.base.repository.put_object_from_file(dataset_temp_file, self._list_key)
                # Store metadata as attributes
                self.base.attributes.set('dataset_size', len(data))
                self.base.attributes.set('num_labelled_frames', num_labelled_frames)
                self.base.attributes.set('num_unlabelled_frames', num_unlabelled_frames)
                self.base.attributes.set('atomic_species', list(symb))
        except Exception as e:
            print(f"An error occurred while saving '{self._list_key}': {e}")

    def set_list(self, data):
        """Set the contents of this node by saving a list as an HDF5 file."""
        # Ensure data is a list
        if not isinstance(data, list):
            raise TypeError("Input data must be a list.")

        num_labelled_frames = 0
        num_unlabelled_frames = 0
        symb = set()
        for item in data:
            if "dft_forces" in item and "dft_energy" in item:
                num_labelled_frames += 1
            else:
                num_unlabelled_frames += 1
            symb = symb.union(set(item['symbols']))

        save_data = []
        for item in data:
            if 'pbc' not in item:
                item['pbc'] = [True, True, True]
                warnings.warn("Periodic boundary conditions not found in the dataset. "
                              "Assuming PBC = [True, True, True].", UserWarning)
            if 'cell' not in item:
                raise ValueError("Cell vectors not found in the dataset.")
            if 'symbols' not in item:
                raise ValueError("Atomic symbols not found in the dataset.")
            if 'positions' not in item:
                raise ValueError("Atomic positions not found in the dataset.")
            # Ensure that symbols are a list of strings
            item['symbols'] = [str(symbol) for symbol in item['symbols']]
            save_data.append({key: value.tolist() if isinstance(value, np.ndarray) else value
                              for key, value in item.items()})

        try:
            # Create a temporary directory to save the file
            with tempfile.TemporaryDirectory() as temp_dir:
                dataset_temp_file = os.path.join(temp_dir, f"{self._list_key}.h5")
                # Save the data using h5py
                with h5py.File(dataset_temp_file, 'w') as hdf:
                    for idx, item in enumerate(save_data):
                        group = hdf.create_group(f"item_{idx}")
                        for key, value in item.items():
                            if isinstance(value, (list, np.ndarray)):
                                group.create_dataset(key, data=value)
                            else:
                                group.attrs[key] = value
                # Store the file in the AiiDA repository
                self.base.repository.put_object_from_file(dataset_temp_file, self._list_key)
                # Store metadata as attributes
                self.base.attributes.set('dataset_size', len(data))
                self.base.attributes.set('num_labelled_frames', num_labelled_frames)
                self.base.attributes.set('num_unlabelled_frames', num_unlabelled_frames)
                self.base.attributes.set('atomic_species', list(symb))
        except Exception as e:
            print(f"An error occurred while saving '{self._list_key}': {e}")

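    # Sketch of the minimal configuration dictionary accepted by set_list, based
    # on the validation above (all values below are placeholders): 'cell',
    # 'symbols' and 'positions' are required, 'pbc' defaults to fully periodic,
    # and the presence of both 'dft_energy' and 'dft_forces' marks a frame as
    # labelled ('dft_stress' is optional):
    #
    #     config = {
    #         'cell': [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]],
    #         'symbols': ['Si', 'Si'],
    #         'positions': [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
    #         'pbc': [True, True, True],
    #         'dft_energy': -10.5,
    #         'dft_forces': [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    #     }
    #     node = PESData(data=[config])
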
    def _config_to_ase(self, config):
        """
        Helper method to convert a configuration dictionary to an ASE Atoms object.

        :param config: Dictionary containing atomic configuration data
        :return: ASE Atoms object
        """
        atoms = Atoms(
            symbols=config['symbols'],
            positions=config['positions'],
            cell=config['cell'],
            pbc=config['pbc']
        )
        # Attach a single-point calculator with the DFT data, if available
        if 'dft_energy' in config and 'dft_forces' in config:
            calc_kwargs = {
                'energy': config['dft_energy'],
                'forces': config['dft_forces']
            }
            if 'dft_stress' in config:
                calc_kwargs['stress'] = convert_stress(config['dft_stress'])
            atoms.calc = SinglePointCalculator(atoms, **calc_kwargs)
        return atoms

    def get_ase_item(self, index):
        """
        Get a specific configuration as an ASE Atoms object.

        :param index: Index of the configuration to retrieve
        :return: ASE Atoms object
        """
        config = self.get_item(index)
        return self._config_to_ase(config)

    def get_ase(self, index=None):
        """Return the configuration at ``index`` as an ASE Atoms object, or the full list if no index is given."""
        if index is not None:
            return self.get_ase_item(index)
        return self.get_ase_list()

    def get_ase_list(self, max_items=None, warn_threshold=1000):
        """
        Convert the dataset to a list of ASE Atoms objects.

        :param max_items: Optional limit to the number of items to load
        :param warn_threshold: Show warning if loading more than this many items
        :return: List of ASE Atoms objects
        """
        n_items = self.base.attributes.get('dataset_size', 0)
        if n_items == 0:
            # Count the number of items in the HDF5 file
            try:
                with self.base.repository.open(self._list_key, 'rb') as hdf_file:
                    with h5py.File(hdf_file, 'r') as hdf:
                        n_items = len(hdf)
            except (FileNotFoundError, KeyError):
                n_items = 0

        if max_items is None:
            max_items = n_items
        else:
            max_items = min(max_items, n_items)

        if max_items > warn_threshold:
            warnings.warn(f"Loading {max_items} items into memory. This may consume a significant amount of RAM.",
                          UserWarning)

        ase_list = []
        for i, config in enumerate(self.iter_items()):
            if i >= max_items:
                break
            ase_list.append(self._config_to_ase(config))
        return ase_list

    def get_txt(self, write_params=False, key_prefix='', max_items=None, warn_threshold=1000):
        """Alias for get_xyz()."""
        return self.get_xyz(write_params, key_prefix, max_items, warn_threshold)

    def get_xyz(self, write_params=False, key_prefix='', max_items=None, warn_threshold=1000):
        """
        Convert dataset to XYZ format text.

        :param write_params: Whether to include additional parameters in the output
        :param key_prefix: Prefix to add to property keys (energy, forces, stress)
        :param max_items: Optional limit to the number of items to process
        :param warn_threshold: Show warning if processing more than this many items
        :return: Text in XYZ format
        """
        n_items = self.base.attributes.get('dataset_size', 0)
        if n_items == 0:
            # Count the number of items in the HDF5 file
            try:
                with self.base.repository.open(self._list_key, 'rb') as hdf_file:
                    with h5py.File(hdf_file, 'r') as hdf:
                        n_items = len(hdf)
            except (FileNotFoundError, KeyError):
                n_items = 0

        if max_items is None:
            max_items = n_items
        else:
            max_items = min(max_items, n_items)

        if max_items > warn_threshold:
            warnings.warn(f"Processing {max_items} items. This may take some time and consume memory.",
                          UserWarning)

        dataset_txt = ''
        if not key_prefix.endswith('_') and key_prefix != '':
            key_prefix += '_'
        exclude_params = ['cell', 'symbols', 'positions', 'pbc',
                          'forces', 'stress', 'energy',
                          'dft_forces', 'dft_stress', 'dft_energy',
                          'md_forces', 'md_stress', 'md_energy']
        exclude_pattern = re.compile(r'pot_\d+_(energy|forces|stress)')

        for i, config in enumerate(self.iter_items()):
            if i >= max_items:
                break
            params = [key for key in config.keys()
                      if key not in exclude_params and not exclude_pattern.match(key)]
            atm = Atoms(symbols=config['symbols'],
                        positions=config['positions'],
                        cell=config['cell'],
                        pbc=config['pbc'])
            atm.info = {}
            if write_params:
                for key in params:
                    atm.info[key] = config[key]
            # A single atom without full periodic boundary conditions is treated as an isolated atom
            if len(atm.get_chemical_symbols()) == 1 and not atm.get_pbc().all():
                atm.info["config_type"] = "IsolatedAtom"
            if 'dft_stress' in config:
                s = convert_stress(config['dft_stress'])
                atm.info[f'{key_prefix}stress'] = (
                    f"{s[0][0]:.6f} {s[0][1]:.6f} {s[0][2]:.6f} "
                    f"{s[1][0]:.6f} {s[1][1]:.6f} {s[1][2]:.6f} "
                    f"{s[2][0]:.6f} {s[2][1]:.6f} {s[2][2]:.6f}"
                )
            if 'dft_energy' in config:
                atm.info[f'{key_prefix}energy'] = config['dft_energy']
            if 'dft_forces' in config:
                atm.calc = SinglePointCalculator(atm, forces=config['dft_forces'])
            with io.StringIO() as buf, redirect_stdout(buf):
                write('-', atm, format='extxyz', write_results=True, write_info=True)
                dataset_txt += buf.getvalue()

        if key_prefix != '':
            dataset_txt = dataset_txt.replace(':forces', f':{key_prefix}forces')
        return dataset_txt

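    # Example (illustrative): dump only the labelled frames to an extended-XYZ
    # file, prefixing the property keys so they appear as 'dft_energy',
    # 'dft_forces', etc. (the file name is a placeholder):
    #
    #     with open('training_set.xyz', 'w') as fh:
    #         fh.write(dataset.get_labelled().get_xyz(key_prefix='dft'))
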
    def get_unlabelled(self):
        """Return a PESData object with only unlabelled frames."""
        if self.base.attributes.get('num_labelled_frames') == 0:
            return self
        unlabelled_data = []
        for config in self.iter_items():
            if 'dft_forces' not in config or 'dft_energy' not in config:
                unlabelled_data.append(config)
        return PESData(data=unlabelled_data)

    def get_labelled(self):
        """Return a PESData object with only labelled frames."""
        if self.base.attributes.get('num_unlabelled_frames') == 0:
            return self
        labelled_data = []
        for config in self.iter_items():
            if 'dft_forces' in config and 'dft_energy' in config:
                labelled_data.append(config)
        return PESData(data=labelled_data)

    @property
    def len_unlabelled(self):
        """Return the number of unlabelled configurations in the dataset."""
        if "num_unlabelled_frames" in self.base.attributes.keys():
            return self.base.attributes.get('num_unlabelled_frames')
        return len(self.get_unlabelled())

    @property
    def len_labelled(self):
        """Return the number of labelled configurations in the dataset."""
        if "num_labelled_frames" in self.base.attributes.keys():
            return self.base.attributes.get('num_labelled_frames')
        return len(self.get_labelled())

    def iter_ase(self, max_items=None):
        """
        Generator to iterate through ASE Atoms objects without loading all of them into memory.

        :param max_items: Optional limit to the number of items to process
        :yield: ASE Atoms objects one at a time
        """
        for counter, config in enumerate(self.iter_items()):
            if max_items is not None and counter >= max_items:
                break
            yield self._config_to_ase(config)

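    # Example (illustrative): stream configurations as ASE Atoms objects without
    # materialising the whole dataset in memory (`process` is a placeholder):
    #
    #     for atoms in dataset.iter_ase(max_items=100):
    #         process(atoms)
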
    def get_e0s(self, format='atomic_numbers') -> dict:
        """
        Get the energies of isolated atoms from the dataset.

        :param format: Format of the atomic species ('atomic_numbers' or 'symbols')
        :return: Dictionary of isolated-atom energies
        """
        e0s = {k: None for k in self.get_atomic_species()}
        for config in self.iter_items():
            if not all(config['pbc']) and len(config['symbols']) == 1 and 'dft_energy' in config:
                e0s[config['symbols'][0]] = config['dft_energy']
        if format == 'atomic_numbers':
            e0s = {atomic_numbers[k]: v for k, v in e0s.items()}
        return e0s
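
    # Example (illustrative): collect isolated-atom reference energies, keyed by
    # atomic number, from frames that contain a single atom with no periodic
    # boundary conditions (the returned value below is a placeholder):
    #
    #     e0s = dataset.get_e0s(format='atomic_numbers')
    #     # e.g. {14: -0.8} for a dataset containing an isolated Si frame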