Source code for aiida_trains_pot.mace.mace_train_plugin.calculations

"""
AiiDA calculation plugin for the MACE training code.
"""
import re
from aiida.common import datastructures
from aiida.engine import CalcJob
from aiida.orm import SinglefileData, List, FolderData, Dict, Bool, Code, Str
import io
from contextlib import redirect_stdout
from ase.io import write
from ase.calculators.singlepoint import SinglePointCalculator
from ase import Atoms
from aiida.plugins import DataFactory
import random
import yaml

# AiiDA data class loaded from the 'pesdata' entry point; used as the valid type of the
# dataset input ports of the calculation.
PESData = DataFactory('pesdata')

def validate_protocol(node, _):
    """Check that the ``protocol`` input is a supported fine-tuning protocol.

    :param node: node passed to the ``protocol`` port; its ``.value`` is compared.
    :param _: unused port namespace argument required by the validator signature.
    :return: an error message when the value is not supported, otherwise ``None``.
    """
    supported_protocols = ("naive-finetune", "replay-finetune")
    if node.value in supported_protocols:
        return None
    return "The `protocol` input can only be 'naive-finetune' or 'replay-finetune'."

def validate_inputs(inputs, _):
    """Check that the inputs needed by the selected fine-tuning protocol are present.

    :param inputs: mapping of the top-level inputs of the calculation.
    :param _: unused port namespace argument required by the validator signature.
    :return: the error message for the first missing input, otherwise ``None``.
    """
    if 'protocol' not in inputs:
        return None

    protocol = inputs["protocol"].value
    # (required port, error message) pairs, checked in this order.
    if protocol == "naive-finetune":
        requirements = (
            ("finetune_model",
             "The `finetune_model` input is required when using the 'naive-finetune' protocol."),
        )
    elif protocol == "replay-finetune":
        requirements = (
            ("finetune_model",
             "The `finetune_model` input is required when using the 'replay-finetune' protocol."),
            ("finetune_replay_dataset",
             "The `finetune_replay_dataset` input is required when using the 'replay-finetune' protocol."),
        )
    else:
        requirements = ()

    for port_name, error_message in requirements:
        if port_name not in inputs:
            return error_message
    return None

class MaceTrainCalculation(CalcJob):
    """AiiDA calculation plugin wrapping the MACE training code.

    The job writes the training, validation and test sets as XYZ files plus a
    ``config.yml`` for MACE, optionally runs a preprocessing code first, runs the
    training code, and then runs the postprocess code on each model file.
    """

    @classmethod
    def define(cls, spec):
        """Define inputs, outputs and exit codes of the calculation."""
        super().define(spec)

        # Default values for the AiiDA options.
        spec.inputs["metadata"]["options"]["resources"].default = {
            "num_machines": 1,
            "num_mpiprocs_per_machine": 1,
        }
        spec.inputs["metadata"]["options"]["parser_name"].default = "trains_pot.macetrain"

        # Input ports.
        spec.input("training_set", valid_type=PESData, help="Training dataset list")
        spec.input("validation_set", valid_type=PESData, help="Validation dataset list")
        spec.input("test_set", valid_type=PESData, help="Test dataset list")
        spec.input("mace_config", valid_type=Dict, help="Config parameters for MACE")
        spec.input("checkpoints", valid_type=FolderData, help="Checkpoints file", required=False)
        spec.input("do_preprocess", valid_type=Bool, help="Perform preprocess",
                   required=False, default=lambda: Bool(False))
        spec.input("preprocess_code", valid_type=Code,
                   help="Preprocess code, required if do_preprocess is True", required=False)
        spec.input("postprocess_code", valid_type=Code, help="Postprocess code", required=False)
        spec.input("restart", valid_type=Bool, help="Restart from a previous calculation",
                   required=False, default=lambda: Bool(False))
        spec.input("checkpoints_restart", valid_type=FolderData, help="Checkpoints file", required=False)
        spec.input("protocol", valid_type=Str,
                   help="Protocol for the calculation {'naive-finetune' or 'replay-finetune'}",
                   required=False, validator=validate_protocol)
        spec.input("finetune_model", valid_type=SinglefileData, help="Model to finetune", required=False)
        spec.input("finetune_replay_dataset", valid_type=PESData,
                   help="Dataset for replay finetune", required=False)
        spec.inputs.validator = validate_inputs

        # Output ports.
        spec.output("model_stage1_lammps", valid_type=SinglefileData, help="Stage 1 model compiled for LAMMPS")
        spec.output("model_stage1_ase", valid_type=SinglefileData, help="Stage 1 model compiled for ASE")
        spec.output("model_stage1_pytorch", valid_type=SinglefileData, help="Stage 1 model not compiled")
        spec.output("model_stage2_lammps", valid_type=SinglefileData, help="Stage 2 model compiled for LAMMPS")
        spec.output("model_stage2_ase", valid_type=SinglefileData, help="Stage 2 model compiled for ASE")
        spec.output("model_stage2_pytorch", valid_type=SinglefileData, help="Stage 2 model not compiled")
        spec.output("mace_out", valid_type=SinglefileData, help="Mace output file")
        spec.output("logs", valid_type=FolderData, help="Logs file")
        spec.output("checkpoints", valid_type=FolderData, help="Checkpoints file")
        spec.output("RMSE", valid_type=List, help="List of the checkpoints result table")
        spec.output("results", valid_type=List, help="List of the results of log file")

        spec.exit_code(300, "ERROR_MISSING_OUTPUT_FILES",
                       message="Calculation did not produce all expected output files.")
        spec.exit_code(400, "ERROR_OUT_OF_WALLTIME",
                       message="The calculation stopped prematurely because it ran out of walltime.")

    @staticmethod
    def _write_text(folder, filename, text):
        """Write ``text`` to ``filename`` inside the sandbox ``folder``."""
        with folder.open(filename, "w") as handle:
            handle.write(text)

    @staticmethod
    def _copy_checkpoint(source_folder, source_name, target_folder, target_path):
        """Copy the binary file ``source_name`` of ``source_folder`` to ``target_path`` in ``target_folder``."""
        with source_folder.open(source_name, 'rb') as source:
            with target_folder.open(target_path, 'wb') as destination:
                destination.write(source.read())

    @staticmethod
    def _postprocess_codeinfo(code_uuid, model_file):
        """Return a ``CodeInfo`` that runs the postprocess code on ``model_file``."""
        codeinfo = datastructures.CodeInfo()
        codeinfo.code_uuid = code_uuid
        codeinfo.cmdline_params = [model_file]
        return codeinfo

    def prepare_for_submission(self, folder):
        """Create the input files and the ``CalcInfo`` for the MACE training run.

        :param folder: an `aiida.common.folders.Folder` where the plugin should temporarily
            place all files needed by the calculation.
        :return: `aiida.common.datastructures.CalcInfo` instance
        :raises ValueError: if ``do_preprocess`` is True but ``preprocess_code`` is missing,
            or if ``postprocess_code`` is missing.
        """
        mace_config_dict = self.inputs.mace_config.get_dict()
        do_preprocess = self.inputs.do_preprocess.value

        if do_preprocess and 'preprocess_code' not in self.inputs:
            raise ValueError("Preprocess code is required if do_preprocess is True")
        # The postprocess steps are always scheduled below; fail with a clear message
        # instead of an attribute error on the optional port.
        if 'postprocess_code' not in self.inputs:
            raise ValueError("Postprocess code is required to run the MACE training calculation")

        # Optional preprocessing step (only built when requested).
        codeinfo_preprocess = None
        if do_preprocess:
            codeinfo_preprocess = datastructures.CodeInfo()
            codeinfo_preprocess.code_uuid = self.inputs.preprocess_code.uuid
            codeinfo_preprocess.cmdline_params = [
                '--train_file', "training.xyz",
                '--valid_file', "validation.xyz",
                '--test_file', "test.xyz",
                '--energy_key', "dft_energy",
                '--forces_key', "dft_forces",
                '--stress_key', "dft_stress",
                '--compute_statistics',
                '--h5_prefix', "processed_data/",
                '--seed', str(random.randint(0, 10000)),
            ]
            if 'r_max' in mace_config_dict:
                codeinfo_preprocess.cmdline_params += ['--r_max', str(mace_config_dict['r_max'])]

        # Postprocess every model file MACE may write: "aiida_swa.model" (MACE < 0.3.7),
        # "aiida_stagetwo.model" (MACE >= 0.3.7) and "aiida.model".
        postprocess_uuid = self.inputs.postprocess_code.uuid
        postprocess_codeinfos = [
            self._postprocess_codeinfo(postprocess_uuid, model_file)
            for model_file in ("aiida_swa.model", "aiida_stagetwo.model", "aiida.model")
        ]

        # Main MACE training run, driven entirely by config.yml.
        codeinfo = datastructures.CodeInfo()
        codeinfo.cmdline_params = ['--config', 'config.yml']
        codeinfo.code_uuid = self.inputs.code.uuid
        codeinfo.stdout_name = "mace.out"

        datasets = (
            ('training.xyz', self.inputs.training_set),
            ('validation.xyz', self.inputs.validation_set),
            ('test.xyz', self.inputs.test_set),
        )
        for filename, dataset in datasets:
            self._write_text(folder, filename, dataset.get_txt(write_params=False, key_prefix='dft'))

        mace_config_dict['seed'] = random.randint(0, 10000)
        if do_preprocess:
            mace_config_dict['train_file'] = "processed_data/train/"
            mace_config_dict['valid_file'] = "processed_data/val/"
            mace_config_dict['test_file'] = "processed_data/test/"
            mace_config_dict['statistics_file'] = "processed_data/statistics.json"
        else:
            mace_config_dict['train_file'] = "training.xyz"
            mace_config_dict['valid_file'] = "validation.xyz"
            mace_config_dict['test_file'] = "test.xyz"
        mace_config_dict['energy_key'] = "dft_energy"
        mace_config_dict['forces_key'] = "dft_forces"
        mace_config_dict['stress_key'] = "dft_stress"

        # Use the E0s of the training set when all are known; otherwise fall back to
        # "average" (passed to the preprocess code or written to the config).
        if 'E0s' not in mace_config_dict:
            e0s = self.inputs.training_set.get_e0s()
            if None not in e0s.values():
                mace_config_dict['E0s'] = str(e0s)
            else:
                atomic_numbers = self.inputs.training_set.get_atomic_numbers()
                if do_preprocess:
                    codeinfo_preprocess.cmdline_params += [
                        '--E0s=average',
                        f'--atomic_numbers={atomic_numbers}',
                    ]
                else:
                    mace_config_dict['E0s'] = "average"
                    mace_config_dict['atomic_numbers'] = f'"{atomic_numbers}"'

        finetune = 'protocol' in self.inputs
        if finetune:
            protocol = self.inputs.protocol.value
            if protocol == "naive-finetune":
                mace_config_dict['foundation_model'] = "finetune_model.dat"
                mace_config_dict['multiheads_finetuning'] = False
            elif protocol == "replay-finetune":
                mace_config_dict['foundation_model'] = "finetune_model.dat"
                mace_config_dict['multiheads_finetuning'] = True
                replay_txt = self.inputs.finetune_replay_dataset.get_txt(write_params=False, key_prefix='dft')
                self._write_text(folder, 'replay.xyz', replay_txt)
                mace_config_dict['pt_train_file'] = "replay.xyz"
            # When fine-tuning, an "average" E0s setting is dropped from the config.
            # NOTE(review): presumably so MACE uses its fine-tuning default for E0s; the
            # original indentation was lost, confirm this belongs to the fine-tuning branch.
            if mace_config_dict.get('E0s') == "average":
                del mace_config_dict['E0s']

        if 'checkpoints' in self.inputs:
            mace_config_dict['restart_latest'] = True

        # Enable distributed training when more than one MPI process per machine is requested.
        resources = self.inputs["metadata"]["options"]["resources"]
        num_mpiprocs = resources.get('num_mpiprocs_per_machine') or 1
        if not mace_config_dict.get('distributed', False) and num_mpiprocs > 1:
            mace_config_dict['distributed'] = True

        # Restart from given checkpoints: they are renamed to use this run's seed and epoch 0
        # (presumably so MACE's restart_latest finds them — confirm).
        if 'checkpoints' in self.inputs and self.inputs.restart.value:
            mace_config_dict['restart_latest'] = True
            checkpoints_folder = self.inputs.checkpoints
            folder.get_subfolder('checkpoints', create=True)
            seed = mace_config_dict['seed']
            for checkpoint_file in checkpoints_folder.list_object_names():
                if '_epoch' not in checkpoint_file:
                    continue
                if '_swa' in checkpoint_file:
                    new_checkpoint_file = f"aiida_run-{seed}_epoch-0_swa.pt"
                else:
                    new_checkpoint_file = f"aiida_run-{seed}_epoch-0.pt"
                self._copy_checkpoint(checkpoints_folder, checkpoint_file,
                                      folder, f'checkpoints/{new_checkpoint_file}')

        # Continue from checkpoints of an interrupted run: keep their names and reuse the
        # seed encoded in them (the first number between '-' and '_').
        if 'checkpoints_restart' in self.inputs:
            mace_config_dict['restart_latest'] = True
            checkpoints_folder = self.inputs.checkpoints_restart
            folder.get_subfolder('checkpoints', create=True)
            for checkpoint_file in checkpoints_folder.list_object_names():
                if '_epoch' not in checkpoint_file:
                    continue
                match = re.search(r'-(\d+)_', checkpoint_file)
                if match:
                    mace_config_dict['seed'] = int(match.group(1))
                self._copy_checkpoint(checkpoints_folder, checkpoint_file,
                                      folder, f'checkpoints/{checkpoint_file}')

        # Written once, after every modification of the configuration.
        with folder.open('config.yml', 'w') as yaml_file:
            yaml.dump(mace_config_dict, yaml_file, default_flow_style=False)

        calcinfo = datastructures.CalcInfo()
        if finetune:
            calcinfo.local_copy_list = [
                (self.inputs.finetune_model.uuid, self.inputs.finetune_model.filename, "finetune_model.dat"),
            ]
        preprocess_steps = [codeinfo_preprocess] if do_preprocess else []
        calcinfo.codes_info = preprocess_steps + [codeinfo, *postprocess_codeinfos]
        calcinfo.retrieve_list = ['*model*', 'checkpoints', 'mace.out', 'results', 'logs', '_scheduler-std*']

        return calcinfo