from aiida.engine import WorkChain, ToContext, append_, calcfunction, while_
from aiida.orm import Float, Dict, List, Str, SinglefileData, StructureData
from aiida.plugins import WorkflowFactory, DataFactory
from aiida.common import AttributeDict
import os
from aiida_lammps.data.potential import LammpsPotentialData
from pathlib import Path
import tempfile
import random # to change seed for each retry
from aiida_quantumespresso.workflows.protocols.utils import recursive_merge
from aiida_trains_pot.utils.lammps_pair_coeffs import get_dftd2_pair_coeffs, get_mace_pair_coeff
# Plugin classes are looked up by name through the AiiDA factories:
# - LammpsWorkChain is the workflow submitted for every MD run in this module;
# - PESData is the data type accepted for the `lammps_input_structures` input.
LammpsWorkChain = WorkflowFactory('lammps.base')
PESData = DataFactory('pesdata')
def generate_potential(potential, pair_style) -> LammpsPotentialData:
    """
    Generate the potential to be used in the calculation.

    The content of ``potential`` is copied to a temporary file, because
    ``LammpsPotentialData.get_or_create`` is given a filesystem path as source,
    and the temporary file is removed again before returning.

    :param potential: file-like node exposing ``open(mode='rb')`` that holds
        the potential (the work chain passes a ``SinglefileData``)
    :param pair_style: LAMMPS pair style stored together with the potential
    :return: potential to do the calculation
    :rtype: LammpsPotentialData
    """
    potential_parameters = {
        "species": [],
        "atom_style": "atomic",
        "units": "metal",
        "extra_tags": {},
    }

    # Read the source first: if this fails no temporary file has been created yet.
    with potential.open(mode='rb') as potential_handle:
        potential_content = potential_handle.read()

    # delete=False: the file must outlive the ``with`` block so that it can be
    # re-opened by path below; it is removed explicitly in the ``finally``.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_file.write(potential_content)
        tmp_file_path = tmp_file.name

    try:
        return LammpsPotentialData.get_or_create(
            source=Path(tmp_file_path),
            pair_style=pair_style,
            **potential_parameters,
        )
    finally:
        # Clean up even when ``get_or_create`` raises, so failed runs do not
        # leave copies of the potential behind in the temp directory.
        os.remove(tmp_file_path)
###################################################################
## DEFAULT VALUES ##
###################################################################
# Pair style used when the `potential_pair_style` input is not given.
DEFAULT_potential_pair_style = Str("mace no_domain_decomposition")

# Base LAMMPS settings; user-provided settings are merged on top of these.
DEFAULT_settings = Dict({"store_restart": True})

# Base LAMMPS parameters; user-provided parameters are merged on top of these.
# The empty "md" and "dump" sections are filled in by the work chain.
DEFAULT_parameters = Dict(
    {
        "structure": {
            "atom_style": "atomic",
            "atom_modify": "map yes",
            "dimension": "3",
            "boundary": "p p p",
        },
        "potential": {
            "potential_style_options": "mace no_domain_decomposition",
        },
        "control": {
            "timestep": 0.001,
            "newton": "on",
            "units": "metal",
        },
        "thermo": {
            "printing_rate": 100,
            # Every listed thermo keyword is switched on.
            "thermo_printing": dict.fromkeys(
                (
                    "step",
                    "time",
                    "pe",
                    "ke",
                    "etotal",
                    "press",
                    "pxx",
                    "pyy",
                    "pzz",
                    "temp",
                ),
                True,
            ),
        },
        "restart": {
            "print_final": True,
        },
        "md": {},
        "dump": {},
    }
)
###################################################################
# [docs]  (Sphinx "view source" link residue; not part of the module source)
class ExplorationWorkChain(WorkChain):
    """A workchain to loop over structures and submit LammpsWorkChain with retries.

    One ``LammpsWorkChain`` is submitted for every combination of input
    structure and MD parameter set.  Runs that end with a non-zero exit status
    are resubmitted with the inputs of the failed run, until they succeed or
    the submission limit is reached.  The outputs of the successful runs are
    returned in the ``md`` output namespace.

    Context bookkeeping (indices are submission counters; the code assumes they
    coincide with positions in ``ctx.md_wc``):

    * ``iteration``: index that the next submitted work chain will get;
    * ``dict_wc``: maps ``str(index)`` to the index of the first submission of
      the same MD task, so that retries can be counted per task;
    * ``last_wc``: indices submitted during the most recent outline step;
    * ``rerun_wc`` / ``rerun_wc_old``: indices selected for resubmission in the
      current round / in all previous rounds.
    """

    # Maximum number of submissions (first run included) of a single MD task.
    _MAX_SUBMISSIONS = 3
    # ``potential_style_options`` applied by the 'vdw_d2' protocol when the
    # user parameters do not set one.
    _VDW_D2_STYLE_OPTIONS = 'mace no_domain_decomposition momb 20.0 0.75 20.0'

    @classmethod
    def define(cls, spec):
        """Declare the inputs, the outputs and the outline of the work chain."""
        super().define(spec)
        spec.input('params_list', valid_type=List, help='List of parameters for md')
        spec.input('parameters', valid_type=Dict, help='Global parameters for lammps')
        spec.input('potential_lammps', valid_type=SinglefileData, required=False, help='One of the potential for MD')
        spec.input('potential_pair_style', valid_type=Str, default=lambda:DEFAULT_potential_pair_style, required=False, help=f"General potential pair style. Default: {DEFAULT_potential_pair_style}")
        spec.input('sampling_time', valid_type=Float, help='Correlation time for frame extraction')
        spec.input('protocol', valid_type=Str, help='Protocol for the calculation', required=False)
        spec.input('lammps_input_structures', valid_type=PESData, help='Input structures for lammps')
        # Structure, potential and parameters are built by this work chain for
        # every run, so they are not exposed to the caller.
        spec.expose_inputs(
            LammpsWorkChain,
            namespace="md",
            exclude=('lammps.structure', 'lammps.potential', 'lammps.parameters'),
            namespace_options={'validator': None},
        )
        spec.output_namespace("md", dynamic=True, help="Exploration outputs")
        # A single outline: a second ``spec.outline`` call would simply replace
        # the first one.
        spec.outline(
            cls.run_md,
            while_(cls.not_converged)(
                cls.run_restart,
            ),
            cls.finalize_md,
        )

    def _submit_md(self, inputs, root_index=None):
        """Submit one LammpsWorkChain and record it in the context bookkeeping.

        :param inputs: inputs passed to ``LammpsWorkChain``
        :param root_index: index of the first submission of the same MD task;
            ``None`` marks a first submission, which is its own root
        """
        index = self.ctx.iteration
        future = self.submit(LammpsWorkChain, **inputs)
        self.to_context(md_wc=append_(future))
        self.ctx.dict_wc[f'{index}'] = index if root_index is None else root_index
        self.ctx.last_wc.append(index)
        self.ctx.iteration += 1

    def run_md(self):
        """Run MD simulations for each structure and MD parameter set.

        Failures are not handled here: they are detected by ``not_converged``
        and resubmitted by ``run_restart``.
        """
        self.ctx.rerun_wc = []
        self.ctx.rerun_wc_old = []
        self.ctx.last_wc = []
        self.ctx.dict_wc = {}
        self.ctx.iteration = 0

        # ``protocol`` is optional and has no default: when it is not supplied
        # it is absent from ``self.inputs``, so attribute access would raise.
        protocol = self.inputs.protocol.value if 'protocol' in self.inputs else None
        use_vdw_d2 = protocol == 'vdw_d2'

        # Pair coefficients are generated unless the user supplies their own.
        user_potential = self.inputs.parameters.get_dict().get('potential', {})
        generate_pair_coeff = 'pair_coeff_list' not in user_potential

        # The exposed ``settings`` input may be missing; fall back to defaults.
        lammps_inputs = self.inputs.md.lammps
        user_settings = lammps_inputs.settings.get_dict() if 'settings' in lammps_inputs else {}
        settings = recursive_merge(DEFAULT_settings.get_dict(), user_settings)

        # The potential does not depend on the structure: build it once.
        potential = generate_potential(
            self.inputs.potential_lammps,
            str(self.inputs.potential_pair_style.value),
        )
        sampling_time = self.inputs.sampling_time.value

        # Loop over structures
        for structure in self.inputs.lammps_input_structures.get_ase_list():
            structure_node = StructureData(ase=structure)

            inputs = self.exposed_inputs(LammpsWorkChain, namespace="md")
            inputs.lammps.structure = structure_node
            inputs.lammps.potential = potential
            inputs.lammps.settings = Dict(settings)

            # Fresh copy for every structure because it is modified below.
            input_parameters = self.inputs.parameters.get_dict()

            if use_vdw_d2:
                # Only fill in the style options the user did not set.
                input_parameters.setdefault('potential', {}).setdefault(
                    'potential_style_options', self._VDW_D2_STYLE_OPTIONS
                )

            if generate_pair_coeff:
                if use_vdw_d2:
                    # DFT-D2 pair coefficients followed by the MACE one in hybrid form.
                    pair_coeffs = get_dftd2_pair_coeffs(structure_node)
                    pair_coeffs.append(get_mace_pair_coeff(structure_node, hybrid=True))
                else:
                    # Plain MACE potential, no hybrid/overlay.
                    pair_coeffs = [get_mace_pair_coeff(structure_node, hybrid=False)]
                input_parameters.setdefault('potential', {})['pair_coeff_list'] = pair_coeffs

            parameters = recursive_merge(DEFAULT_parameters.get_dict(), input_parameters)
            # ``round`` rather than ``int``: a ratio such as 0.3 / 0.001 evaluates
            # to 299.99999999999994 and would be truncated to 299.
            parameters['dump']['dump_rate'] = round(sampling_time / parameters['control']['timestep'])

            pbc = structure_node.pbc

            # Loop over the MD parameter sets (fetched again for every structure
            # because the entries are modified in place below).
            for params_md in self.inputs.params_list.get_list():
                integration = params_md["integration"]
                if not any(pbc):
                    integration["style"] = "nvt"
                # Remove constraints for non-periodic directions
                constraints = integration["constraints"]
                for axis, periodic in zip(("x", "y", "z"), pbc):
                    if not periodic:
                        constraints.pop(axis, None)  # Avoid KeyError if axis doesn't exist

                parameters['md'] = dict(params_md)
                inputs.lammps.parameters = Dict(parameters)
                self._submit_md(inputs)

    def run_restart(self):
        """Resubmit every work chain selected by ``not_converged``.

        The ``lammps`` inputs of the failed run are read back from its incoming
        links and merged over the exposed inputs.
        """
        # NOTE(review): inputs are resubmitted unchanged; the module imports
        # ``random`` "to change seed for each retry" but never uses it -- confirm
        # that a plain resubmission is intended.
        self.ctx.last_wc = []
        for ii, calc in enumerate(self.ctx.md_wc):
            if ii not in self.ctx.rerun_wc:
                continue
            incoming = calc.base.links.get_incoming().nested()
            # Build the inputs dictionary
            inputs = self.exposed_inputs(LammpsWorkChain, namespace="md")
            if 'lammps' in incoming:
                inputs['lammps'].update(incoming['lammps'])  # Merge nested inputs
            # Retries inherit the root index of the run they replace.
            self._submit_md(inputs, root_index=self.ctx.dict_wc[f'{ii}'])

    def not_converged(self):
        """Check if any calculation did not end successfully and requires a restart.

        :return: ``True`` when at least one work chain was selected for
            resubmission (stored in ``ctx.rerun_wc``)
        """
        # Update the old list of reruns and prepare a new one
        self.ctx.rerun_wc_old.extend(self.ctx.rerun_wc)
        self.ctx.rerun_wc = []

        roots = list(self.ctx.dict_wc.values())
        for ii, calc in enumerate(self.ctx.md_wc):
            # Only the runs of the last step can be new failures.
            if ii not in self.ctx.last_wc or calc.exit_status == 0:
                continue
            # Number of submissions so far (first run included) of this MD task
            submissions = roots.count(self.ctx.dict_wc[f'{ii}'])
            # Resubmit only while the task has been submitted fewer than
            # _MAX_SUBMISSIONS times and this run was not already resubmitted
            if submissions < self._MAX_SUBMISSIONS and ii not in self.ctx.rerun_wc_old:
                self.ctx.rerun_wc.append(ii)

        return len(self.ctx.rerun_wc) > 0

    def finalize_md(self):
        """Collect the results from the completed LAMMPS calculations.

        Only work chains with exit status 0 contribute; each one is returned
        under ``md.md_<index>`` with all of its outputs.
        """
        md_out = {}
        for ii, calc in enumerate(self.ctx.md_wc):
            if calc.exit_status == 0:
                self.report(f'md_{ii} exit0')
                md_out[f'md_{ii}'] = {el: calc.outputs[el] for el in calc.outputs}
        self.out('md', md_out)