# Source code for aiida_trains_pot.aiida_trains_pot_workflow.training_wc

from aiida.engine import WorkChain, ToContext, append_, calcfunction
from aiida.orm import StructureData, Dict, List, Int, Bool, FolderData
from aiida.plugins import WorkflowFactory, DataFactory
from aiida.common import AttributeDict
import random
import itertools
import time

# Plugin classes looked up through the AiiDA factories by entry-point name:
# the workflow submitted once per potential, and the dataset node type used
# for the input dataset and for every split produced below.
MaceWorkChain   = WorkflowFactory('trains_pot.macetrain')
PESData         = DataFactory('pesdata')

@calcfunction
def SplitDataset(dataset):
    """Divide a dataset into training, validation and test sets.

    The entries returned by ``dataset.get_list()`` (dictionaries) are grouped
    by every key/value pair whose key does not contain one of the excluded
    substrings (energy, forces, positions, ...). Each group is then handled
    according to its first entry:

    * if it has a ``gen_method`` equal to ``INPUT_STRUCTURE``,
      ``ISOLATED_ATOM`` or ``EQUILIBRIUM``, or (when ``gen_method`` is
      present) it holds a single position, the whole group goes to the
      training set;
    * if it already carries a ``set`` label (``TRAINING``, ``VALIDATION`` or
      ``TEST``), the whole group goes to that set;
    * otherwise the group is shuffled, the first ~80% go to training and the
      remainder is dealt alternately to validation and test.

    Every entry is finally labelled with its ``set`` and receives
    ``gen_method = 'UNKNOWN'`` when it has none.

    :param dataset: node exposing ``get_list()`` that returns a list of dicts.
    :return: dict with the ``PESData`` nodes ``train_set``,
        ``validation_set``, ``test_set`` and ``global_splitted`` (the
        concatenation validation + test + training).
    """
    data = dataset.get_list()

    # Keys containing any of these substrings are per-configuration payload
    # and are ignored when building the grouping key.
    exclude_list = (
        "energy", "cell", "stress", "forces", "symbols", "positions",
        "id_lammps", "input_structure_uuid", "sigma_strain",
    )
    # Generation methods whose configurations must always be trained on.
    always_training = ("INPUT_STRUCTURE", "ISOLATED_ATOM", "EQUILIBRIUM")

    def _is_grouping_field(name):
        """Return True when ``name`` contains none of the excluded substrings."""
        return not any(el in name for el in exclude_list)

    def get_grouping_key(d):
        """Return the tuple of (key, value) pairs that identifies a group."""
        return tuple((k, v) for k, v in d.items() if _is_grouping_field(k))

    def _label(entries, set_name):
        """Tag every entry with its set and default a missing ``gen_method``."""
        for entry in entries:
            entry['set'] = set_name
            entry.setdefault('gen_method', 'UNKNOWN')

    # ``itertools.groupby`` only merges adjacent items, hence the sort on the
    # same key first.
    sorted_data = sorted(data, key=get_grouping_key)

    # One private generator for the whole split. Reseeding the global ``random``
    # module with ``int(time.time())`` for every group gave equal-sized groups
    # processed within the same second the very same permutation and clobbered
    # the process-wide random state.
    rng = random.Random()

    training_set = []
    validation_set = []
    test_set = []

    for _, group in itertools.groupby(sorted_data, key=get_grouping_key):
        group_list = list(group)
        first = group_list[0]

        # Reference / single-atom configurations are never held out.
        if 'gen_method' in first:
            if first['gen_method'] in always_training or len(first['positions']) == 1:
                training_set += group_list
                continue

        # Respect a split that was already assigned upstream.
        if 'set' in first:
            if first['set'] == 'TRAINING':
                training_set += group_list
                continue
            elif first['set'] == 'VALIDATION':
                validation_set += group_list
                continue
            elif first['set'] == 'TEST':
                test_set += group_list
                continue

        # 80% of the group (rounded to the nearest integer) is used for training.
        training_size = round(0.8 * len(group_list))
        rng.shuffle(group_list)

        held_out = group_list[training_size:]
        training_set += group_list[:training_size]
        # Deal the held-out entries alternately: even indices to validation,
        # odd indices to test.
        validation_set += held_out[::2]
        test_set += held_out[1::2]

    _label(training_set, 'TRAINING')
    _label(validation_set, 'VALIDATION')
    _label(test_set, 'TEST')

    pes_training_set = PESData()
    pes_training_set.set_list(training_set)

    pes_validation_set = PESData()
    pes_validation_set.set_list(validation_set)

    pes_test_set = PESData()
    pes_test_set.set_list(test_set)

    pes_global_splitted = PESData()
    pes_global_splitted.set_list(validation_set + test_set + training_set)

    return {
        "train_set": pes_training_set,
        "validation_set": pes_validation_set,
        "test_set": pes_test_set,
        "global_splitted": pes_global_splitted,
    }


class TrainingWorkChain(WorkChain):
    """A workchain that splits a dataset and submits one MaceWorkChain per potential."""

    @classmethod
    def define(cls, spec):
        """Declare the inputs, outputs and outline of the workchain."""
        super().define(spec)

        spec.input("num_potentials", valid_type=Int, default=lambda: Int(1), required=False)
        spec.input("dataset", valid_type=PESData, help="Training dataset")
        spec.input_namespace("checkpoints", valid_type=FolderData, required=False, help="Checkpoints file")
        # The three data sets are produced by this workchain itself, so they
        # are not exposed as inputs.
        spec.expose_inputs(
            MaceWorkChain,
            namespace="mace",
            exclude=('train.training_set', 'train.validation_set', 'train.test_set'),
            namespace_options={'validator': None},
        )

        spec.output_namespace("training", dynamic=True, help="Training outputs")
        spec.output("global_splitted", valid_type=PESData)

        spec.outline(
            cls.run_training,
            cls.finalize,
        )

    def run_training(self):
        """Split the dataset and submit ``num_potentials`` MaceWorkChain runs.

        The split is reported and its concatenation is exposed as the
        ``global_splitted`` output. Each submitted workchain receives the same
        training/validation/test sets, its own ``index_pot`` and, when
        checkpoints were provided, the checkpoint at the same position.
        """
        split_datasets = SplitDataset(self.inputs.dataset)
        train_set = split_datasets["train_set"]
        validation_set = split_datasets["validation_set"]
        test_set = split_datasets["test_set"]

        self.out('global_splitted', split_datasets["global_splitted"])

        self.report(f"Training set size: {len(train_set.get_list())}")
        self.report(f"Validation set size: {len(validation_set.get_list())}")
        self.report(f"Test set size: {len(test_set.get_list())}")

        inputs = self.exposed_inputs(MaceWorkChain, namespace="mace")
        inputs.train["training_set"] = train_set
        inputs.train["validation_set"] = validation_set
        inputs.train["test_set"] = test_set

        if 'checkpoints' in self.inputs:
            inputs['checkpoints'] = self.inputs.checkpoints
            inputs.train['restart'] = Bool(True)

        # Always bound, so the loop below never depends on a conditionally
        # defined name.
        chkpts = []
        if 'checkpoints' in inputs:
            chkpts = list(dict(inputs.checkpoints).values())

        for ii in range(self.inputs.num_potentials.value):
            if 'checkpoints' in self.inputs and ii < len(chkpts):
                inputs.train["checkpoints"] = chkpts[ii]
            # NOTE(review): ``inputs`` is reused across iterations, so a
            # checkpoint set for an earlier index stays in ``inputs.train``
            # for later indices that have no checkpoint of their own --
            # confirm this is intended.
            inputs.train["index_pot"] = Int(ii)

            future = self.submit(MaceWorkChain, **inputs)
            self.to_context(mace_wc=append_(future))

    def finalize(self):
        """Collect the outputs of every submitted workchain under ``training``.

        The outputs of the ``ii``-th workchain are stored under the key
        ``mace_<ii>``.
        """
        results = {}
        for ii, calc in enumerate(self.ctx.mace_wc):
            results[f'mace_{ii}'] = {el: calc.outputs[el] for el in calc.outputs}

        self.out('training', results)