Module 3 · Implementation
MAML + NequIP — Full Pipeline
We integrate Modules 1 and 2 into a complete meta-learning pipeline for protein force-field adaptation. The pipeline has four phases: meta-dataset construction, meta-training, task adaptation, and QM/MM integration.
3.1 Phase 0 — Meta-Dataset Construction
Define the task library \(\{\tau_i\}_{i=1}^N\). Each task is a distinct molecular microenvironment: a protein-fragment family, an amino acid in different protonation states, a cofactor-binding pocket, or a set of conformationally distinct scaffolds. For each task:
- Sample \(M_{\text{total}} \approx 200\) configurations from short classical MD at the target temperature \(T\).
- Compute DFT energies and forces (ωB97X-D/6-311+G**) for all configurations.
- Split: 40 configs → support \(\mathcal{S}_i\); 160 configs → query \(\mathcal{Q}_i\).
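The 40/160 split in the last step can be sketched as follows. This is a minimal illustration, assuming the configurations of one task are held in a Python list; the function name `split_support_query`, the fixed seed, and the placeholder frame indices are hypothetical and not part of the pipeline itself.

```python
import random

def split_support_query(configs, n_support=40, seed=0):
    """Shuffle one task's configurations and split them into support and query sets."""
    if n_support >= len(configs):
        raise ValueError("n_support must be smaller than the number of configurations")
    rng = random.Random(seed)   # local RNG so the split is reproducible
    shuffled = list(configs)    # copy, so the caller's list is left untouched
    rng.shuffle(shuffled)
    return shuffled[:n_support], shuffled[n_support:]

# Placeholder frame indices stand in for the ~200 DFT-labelled configurations.
frames = list(range(200))
support, query = split_support_query(frames)
print(len(support), len(query))  # 40 160
```

Shuffling before splitting is a design choice made here because consecutive MD frames tend to be correlated; a block-wise or time-ordered split may be preferable if the query set should probe extrapolation along the trajectory.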
3.2 Phase 1 — Meta-Training
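In the notation of the code comments in the loop for this phase, meta-training can be summarized by two updates. The symbol \(\beta\) for the outer step size and the batch symbol \(\mathcal{B}\) are introduced here for illustration only; the loop applies the outer gradient through AdamW with gradient clipping rather than a plain gradient step.

\[
\phi_i^{(0)} = \theta, \qquad
\phi_i^{(j+1)} = \phi_i^{(j)} - \alpha \nabla_{\phi} \mathcal{L}\big(\phi_i^{(j)}; \mathcal{S}_i\big), \quad j = 0, \dots, k-1,
\]
\[
\theta \leftarrow \theta - \beta \nabla_{\theta} \sum_{\tau_i \in \mathcal{B}} \mathcal{L}\big(\phi_i^{(k)}(\theta); \mathcal{Q}_i\big).
\]

Here \(\alpha = 10^{-3}\) is the `lr` passed to `MAML`, \(k = 5\), and each batch \(\mathcal{B}\) holds 8 tasks. Because \(\phi_i^{(k)}\) depends on \(\theta\) through every inner step, the outer gradient contains second derivatives of the support loss, which is what `first_order=False` retains.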
# Meta-training loop for MAML-NequIP
# Assumed to be defined elsewhere: `config` (parsed NequIP config.yaml),
# `meta_dataloader` (yields batches of (support, query) tasks) and
# `energy_force_loss(model, data)` (scalar energy + force loss).
import torch
from nequip.model import model_from_config
from learn2learn.algorithms import MAML

# Initialize NequIP model and wrap in MAML
base_model = model_from_config(config)                 # NequIP config.yaml
maml = MAML(base_model, lr=1e-3, first_order=False)    # lr = inner-loop step size α
meta_opt = torch.optim.AdamW(maml.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(meta_opt, T_max=500)

for epoch in range(500):                               # meta-epochs
    for task_batch in meta_dataloader:                 # 8 tasks per batch
        meta_loss = 0.0                                # reset per batch so backward() sees a fresh graph
        for task in task_batch:
            learner = maml.clone()                     # copy of θ for this task
            support, query = task
            # ── Inner loop (k = 5 adaptation steps) ──
            for _ in range(5):
                s_loss = energy_force_loss(learner, support)
                learner.adapt(s_loss)                  # φᵢ ← φᵢ - α∇L(φᵢ; S)
            # ── Outer loss on query set ──
            meta_loss += energy_force_loss(learner, query)
        # One meta-update per task batch
        meta_opt.zero_grad()
        meta_loss.backward()                           # 2nd-order grad through inner loop
        torch.nn.utils.clip_grad_norm_(maml.parameters(), 1.0)
        meta_opt.step()
    scheduler.step()                                   # one scheduler step per meta-epoch (T_max=500)

# Save meta-initialisation θ*
torch.save(maml.state_dict(), 'theta_star.pt')

3.3 Phase 2 — Task Adaptation
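Given a new protein system with only \(n \ll N_{\text{standard}}\) QM reference configurations, where \(N_{\text{standard}}\) denotes the number typically needed to train a potential from scratch, adaptation takes a few gradient steps and proceeds in seconds: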
# Load meta-initialisation (`maml` is the MAML-wrapped NequIP model built as in Phase 1)
maml.load_state_dict(torch.load('theta_star.pt'))
learner = maml.clone()

# Adaptation support set (n = 20–50 DFT configs)
support_IFP = load_task_data('IFP_scaffold_support.xyz')

# Adapt: 1–10 gradient steps
for step in range(5):
    loss = energy_force_loss(learner, support_IFP)
    learner.adapt(loss)
    print(f"Step {step}: loss = {loss.item():.4f}")   # loss before this step's update

# φ_IFP is now the adapted NNP for the IFP scaffold
phi_IFP = learner                                      # ready for RPMD

3.4 Phase 3 — QM/MM Integration
The adapted NequIP potential serves as the force field for the MM region in a QM/MM calculation. The QM region (typically 10–50 atoms around the reactive site) is treated at the DFT or DLPNO-CCSD(T) level. The two regions are coupled through electrostatic embedding, with link atoms capping the bonds cut at the QM/MM boundary.
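Link-atom placement can be illustrated with a small geometric sketch. It assumes the common convention of putting a hydrogen cap on the line from the QM boundary atom to the MM boundary atom at a fixed distance; the function name `place_link_atom`, the 1.09 Å C–H distance, and the coordinates are illustrative assumptions, not values taken from the pipeline above.

```python
import math

def place_link_atom(r_qm, r_mm, d_link=1.09):
    """Return the link-atom position on the QM->MM bond, d_link (in Å) from the QM atom."""
    bond = [m - q for q, m in zip(r_qm, r_mm)]       # vector from QM atom to MM atom
    length = math.sqrt(sum(c * c for c in bond))
    if length == 0.0:
        raise ValueError("QM and MM boundary atoms coincide")
    scale = d_link / length                           # fraction of the bond to travel
    return [q + scale * c for q, c in zip(r_qm, bond)]

# Hypothetical C(QM)-C(MM) bond of 1.53 Å along the x axis.
r_qm = [0.0, 0.0, 0.0]
r_mm = [1.53, 0.0, 0.0]
link = place_link_atom(r_qm, r_mm)
print(link)
```

Some QM/MM codes instead scale the link-atom distance with the instantaneous bond length; the fixed-distance variant is used here only because it is the simplest to state.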