#!/bin/bash
# Fine-tune the MACE-MPA-0 (medium) foundation model on the "mace_uhtc" data
# (train.xyz / valid.xyz / test.xyz in the working directory), launched with
# Slurm's srun in MACE's --distributed mode on CUDA devices.
#
# Usage (as a Slurm batch script):
#   sbatch [sbatch options] <this script> [extra run_train.py options...]
# Extra arguments are appended after the defaults below. For ordinary
# value options, Python argparse keeps the last occurrence, so e.g.
# --max_num_epochs=500 overrides the default given here.
#
# Optional environment overrides:
#   MACE_ROOT              checkout containing mace/cli/run_train.py
#                          (default: /path-to-mace -- placeholder, set it)
#   MACE_FOUNDATION_MODEL  model file to fine-tune from
#                          (default: /path-to-mpa0/mace-mpa-0-medium.model)
set -euo pipefail

# OpenMP threads per task. NOTE(review): presumably this should match the
# job's --cpus-per-task allocation; confirm against the sbatch options.
export OMP_NUM_THREADS=16

readonly mace_root="${MACE_ROOT:-/path-to-mace}"
readonly foundation_model="${MACE_FOUNDATION_MODEL:-/path-to-mpa0/mace-mpa-0-medium.model}"

# Options live in an array instead of a chain of "\" continuations, so
# adding/removing the last option can never splice a following line of the
# script into srun's argument list.
train_args=(
  --name="mace_uhtc"
  --foundation_model="$foundation_model"

  # Dataset splits.
  --train_file="train.xyz"
  --valid_file="valid.xyz"
  --test_file="test.xyz"

  # Model/loss and the REF_* label keys looked up in the .xyz files.
  --model="MACE"
  --loss="stress"
  --energy_key="REF_energy"
  --forces_key="REF_forces"
  --stress_key="REF_stress"

  # Batching, reproducibility and schedule length.
  --batch_size=10
  --valid_batch_size=10
  --seed=123
  --max_num_epochs=400

  # SWA stage from epoch 200 with its own loss weights and learning rate.
  --swa
  --start_swa=200
  --swa_energy_weight=1000
  --swa_forces_weight=100
  --swa_stress_weight=1000
  --swa_lr=0.0001

  # Weight averaging / optimizer settings.
  --ema
  --ema_decay=0.99
  --amsgrad

  # Runtime: resume from latest checkpoint, GPU, cuEquivariance, multi-rank.
  --restart_latest
  --device=cuda
  --enable_cueq=true
  --distributed
)

srun python "$mace_root/mace/cli/run_train.py" "${train_args[@]}" "$@"
