Module 3 · Implementation

MAML + NequIP — Full Pipeline

We integrate Modules 1 and 2 into a complete meta-learning pipeline for protein force-field adaptation. The pipeline has four phases: meta-dataset construction, meta-training, task adaptation, and QM/MM integration.

3.1 Phase 0 — Meta-Dataset Construction

Define the task library \(\{\tau_i\}_{i=1}^N\). Each task is a distinct molecular microenvironment: a protein-fragment family, an amino acid in different protonation states, a cofactor-binding pocket, or a set of conformationally distinct scaffolds. For each task:

  1. Sample \(M_{\text{total}} \approx 200\) configurations from a short classical MD run at the target temperature \(T\).
  2. Compute DFT energies and forces (ωB97X-D/6-311+G**) for all configurations.
  3. Split: 40 configs → support \(\mathcal{S}_i\); 160 configs → query \(\mathcal{Q}_i\).
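
The split in step 3 can be sketched as follows. This is a minimal illustration, assuming each task's configurations sit in a Python list and that a random, non-overlapping split is acceptable. The function name `split_support_query`, the seed, and the placeholder frame labels are hypothetical and not part of the pipeline described here.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")

def split_support_query(
    configs: Sequence[T], n_support: int = 40, seed: int = 0
) -> Tuple[List[T], List[T]]:
    """Randomly partition one task's configurations into support and query sets."""
    if not 0 < n_support < len(configs):
        raise ValueError("n_support must be between 1 and len(configs) - 1")
    rng = random.Random(seed)          # fixed seed so the split is reproducible
    shuffled = list(configs)
    rng.shuffle(shuffled)
    return shuffled[:n_support], shuffled[n_support:]

# Placeholder labels standing in for 200 DFT-labelled frames of one task.
frames = [f"frame_{i:03d}" for i in range(200)]
support, query = split_support_query(frames)
print(len(support), len(query))  # 40 160
```

Consecutive MD frames are time-correlated, so in practice one might subsample the trajectory at a stride before splitting to keep near-duplicate frames out of both sets.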

3.2 Phase 1 — Meta-Training

# Meta-training loop for MAML-NequIP
import torch
from nequip.model import model_from_config
from learn2learn.algorithms import MAML

# Assumed to be defined elsewhere:
#   config                        parsed NequIP config.yaml
#   meta_dataloader               yields batches of (support, query) tasks
#   energy_force_loss(model, d)   scalar energy + force loss on dataset d

# Initialize NequIP model and wrap in MAML
base_model = model_from_config(config)
maml = MAML(base_model, lr=1e-3, first_order=False)        # lr = inner step size α
meta_opt = torch.optim.AdamW(maml.parameters(), lr=5e-4)   # outer (meta) optimizer
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    meta_opt, T_max=500)

for epoch in range(500):                  # meta-epochs
    for task_batch in meta_dataloader:    # 8 tasks per batch
        meta_opt.zero_grad()
        meta_loss = 0.0

        for support, query in task_batch:
            learner = maml.clone()        # differentiable copy of θ for this task

            # ── Inner loop (k = 5 adaptation steps) ──
            for _ in range(5):
                s_loss = energy_force_loss(learner, support)
                learner.adapt(s_loss)     # φᵢ ← φᵢ - α∇L(φᵢ; Sᵢ)

            # ── Outer loss on query set ──
            meta_loss = meta_loss + energy_force_loss(learner, query)

        # One meta-update per task batch, averaged over its tasks
        meta_loss = meta_loss / len(task_batch)
        meta_loss.backward()              # 2nd-order grad through inner loop
        torch.nn.utils.clip_grad_norm_(maml.parameters(), 1.0)
        meta_opt.step()

    scheduler.step()                      # once per meta-epoch, matching T_max=500

# Save meta-initialisation θ*
torch.save(maml.state_dict(), 'theta_star.pt')
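
Written out, the loop above corresponds to the standard MAML bi-level objective. The notation below is a sketch consistent with the code comments: \(\alpha\) is the inner-loop step size (`lr=1e-3`), \(k = 5\) is the number of inner steps, and \(\mathcal{L}\) stands for `energy_force_loss`, whose exact form is not given in this section. In practice the sum over all \(N\) tasks is estimated from sampled tasks.

```latex
\begin{aligned}
\phi_i^{(0)} &= \theta, \\
\phi_i^{(j+1)} &= \phi_i^{(j)} - \alpha \, \nabla_{\phi} \mathcal{L}\big(\phi_i^{(j)}; \mathcal{S}_i\big),
  \qquad j = 0, \dots, k-1, \\
\theta^{*} &= \arg\min_{\theta} \; \frac{1}{N} \sum_{i=1}^{N}
  \mathcal{L}\big(\phi_i^{(k)}(\theta); \mathcal{Q}_i\big).
\end{aligned}
```

With `first_order=False`, the gradient of the outer loss with respect to \(\theta\) is propagated through all \(k\) inner updates, which requires second derivatives of \(\mathcal{L}\).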

3.3 Phase 2 — Task Adaptation

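Given a new protein system with only \(n \ll N_{\text{standard}}\) QM reference configurations, where \(N_{\text{standard}}\) is the number typically needed to train a potential from scratch, adaptation takes a handful of gradient steps and completes in seconds: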

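# Load meta-initialisation (maml is constructed as in Phase 1)
maml.load_state_dict(torch.load('theta_star.pt'))
learner = maml.clone()

# Adaptation support set (n = 20–50 DFT configs)
# load_task_data is assumed to be a user-provided loader
support_IFP = load_task_data('IFP_scaffold_support.xyz')

# Adapt: 1–10 gradient steps (5 here, matching the meta-training inner loop)
for step in range(5):
    loss = energy_force_loss(learner, support_IFP)
    # first_order=True: no meta-gradient is needed at deployment
    learner.adapt(loss, first_order=True)
    print(f"Step {step}: loss = {loss.item():.4f}")   # loss before this update

# φ_IFP is now the adapted NNP for the IFP scaffold
phi_IFP = learner                         # ready for RPMD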
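
The comment above allows 1–10 gradient steps but does not say how to choose the count. One possible approach, sketched below, is to hold out a few support configurations and pick the step count with the lowest held-out loss. This is an assumption rather than part of the described pipeline. `choose_adaptation_steps` and the toy quadratic losses are hypothetical; in real use `adapt_step` would wrap one `learner.adapt(...)` call and `validation_loss` would evaluate `energy_force_loss` on the held-out frames.

```python
from typing import Callable, List, Tuple

def choose_adaptation_steps(
    adapt_step: Callable[[], None],
    validation_loss: Callable[[], float],
    max_steps: int = 10,
) -> Tuple[int, List[float]]:
    """Run up to max_steps adaptation steps and return the step count with
    the lowest held-out loss (0 means the unadapted meta-initialisation)."""
    history = [validation_loss()]      # loss before any adaptation
    for _ in range(max_steps):
        adapt_step()                   # one gradient step on the support set
        history.append(validation_loss())
    best = min(range(len(history)), key=history.__getitem__)
    return best, history

# Toy stand-in: a scalar "model" x, a support loss (x - 1)^2 minimised by
# gradient descent, and a held-out loss (x - 0.8)^2 with a different minimum.
state = {"x": 0.0}

def toy_step() -> None:
    state["x"] -= 0.25 * 2.0 * (state["x"] - 1.0)

def toy_val() -> float:
    return (state["x"] - 0.8) ** 2

best, history = choose_adaptation_steps(toy_step, toy_val, max_steps=5)
print(best)  # 2
```

The helper only reports the best step count and does not restore that state, so the adapted model would be obtained by re-running that many steps from a fresh `maml.clone()`.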

3.4 Phase 3 — QM/MM Integration

The adapted NequIP potential serves as the force field for the MM region in a QM/MM calculation. The QM region (typically 10–50 atoms around the reactive site) is treated at the DFT or DLPNO-CCSD(T) level. The two regions are coupled through electrostatic embedding, with link atoms capping the bonds cut at the QM/MM boundary.
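
For reference, a generic additive QM/MM energy with electrostatic embedding can be written as below, in atomic units. This is the textbook form, not a specification of this pipeline. The source of the MM partial charges \(q_J\) and the way the NequIP energy is restricted to the MM region are not stated in the text and remain assumptions. The index \(A\) runs over QM nuclei (including link atoms), \(J\) over MM atoms, and \(i\) over QM electrons.

```latex
\begin{aligned}
E_{\text{tot}} &= E_{\text{QM}}\big[\{\mathbf{R}_A\}; \{q_J, \mathbf{R}_J\}\big]
  + E_{\text{MM}}\big(\{\mathbf{R}_J\}\big)
  + E_{\text{QM/MM}}^{\text{vdW, bonded}}, \\
\hat{H}_{\text{QM}}^{\text{emb}} &= \hat{H}_{\text{QM}}
  - \sum_{i} \sum_{J \in \text{MM}} \frac{q_J}{|\mathbf{r}_i - \mathbf{R}_J|}
  + \sum_{A \in \text{QM}} \sum_{J \in \text{MM}} \frac{Z_A \, q_J}{|\mathbf{R}_A - \mathbf{R}_J|}.
\end{aligned}
```

In this reading, \(E_{\text{MM}}\) would be supplied by the adapted MAML-NequIP potential, while the embedding terms polarize the QM electron density in the field of the MM charges.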

[Figure: QM region (10–50 atoms, ωB97X-D/6-311+G**); MM region (protein scaffold, MAML-NequIP φ*); link atoms with electrostatic embedding at the boundary; solvent (TIP3P / implicit).]
QM/MM partitioning with MAML-NequIP as the MM force field. The QM region covers the reactive site; the adapted NequIP handles the protein scaffold at near-DFT accuracy.
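
Link-atom placement at the boundary can be sketched as follows. This is a minimal illustration of one common convention: a hydrogen cap placed on the bond vector at a fixed distance from the QM boundary atom. The function name `place_link_atom` and the default of 1.09 Å (a typical C–H bond length) are assumptions, and some QM/MM codes instead scale the bond vector by a fixed ratio.

```python
import math
from typing import Sequence, Tuple

def place_link_atom(
    r_qm: Sequence[float], r_mm: Sequence[float], d_link: float = 1.09
) -> Tuple[float, ...]:
    """Return the position of a hydrogen link atom placed on the QM->MM
    bond vector, d_link (in Å) away from the QM boundary atom."""
    bond = [m - q for q, m in zip(r_qm, r_mm)]
    length = math.sqrt(sum(c * c for c in bond))
    if length == 0.0:
        raise ValueError("QM and MM boundary atoms coincide")
    scale = d_link / length            # fraction of the cut bond to travel
    return tuple(q + scale * b for q, b in zip(r_qm, bond))

# Example: a C(QM)-C(MM) bond of 1.53 Å lying along the x axis.
link = place_link_atom((0.0, 0.0, 0.0), (1.53, 0.0, 0.0))
print(link)
```

Because the link atom's position is a function of the two boundary atoms, forces acting on it are normally redistributed onto those atoms by the chain rule rather than treated as an independent degree of freedom.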