Coverage for gpaw/test/fuzz.py: 80%
269 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
1from __future__ import annotations
3import argparse
4import json
5import os
6import pickle
7import random
8import subprocess
9import sys
10from pathlib import Path
11from time import time
12from typing import Any, TypeVar, Callable, TYPE_CHECKING
14import numpy as np
15from ase import Atoms
16from ase.build import bulk
17from ase.units import Bohr, Ha
18from gpaw.calculator import GPAW as OldGPAW
19from gpaw.mpi import world
20from gpaw.new.ase_interface import GPAW as NewGPAW
if TYPE_CHECKING:
    # Aliases that are referenced only inside annotations.  The module uses
    # ``from __future__ import annotations``, so annotations are never
    # evaluated at runtime and these names need not exist outside of
    # static type checking.
    T = TypeVar('T')
    # A "pick" function receives a list of candidate values and returns
    # the list of values that should actually be iterated over.
    PickFunc = Callable[[list[T]], list[T]]
def _parse_kpts(text: str) -> float | list[int]:
    """Parse one ``--kpts`` entry.

    ``'AxBxC'`` is an explicit grid (list of ints); anything else is read
    as a float, which create_parameters() turns into ``{'density': value}``.
    """
    if 'x' in text:
        return [int(k) for k in text.split('x')]
    return float(text)


def _iter_jobs(system_names, repeat_all, vacuum_all, pbc_all, magmoms,
               mode_all, kpts_all, code_all, ncores_all, use_symmetry_all,
               complex_all, pick):
    """Lazily yield ``(atoms, params, tag, xtag)`` for every combination.

    *tag* identifies the physical problem (system + mode + k-points) and
    *xtag* the technical variation (code, cores, symmetry, dtype).
    """
    for atoms, atag in create_systems(system_names, repeat_all, vacuum_all,
                                      pbc_all, magmoms, pick):
        for params, ptag in create_parameters(mode_all, kpts_all, pick):
            tag = atag + ' ' + ptag
            for extra, xtag in create_extra_parameters(code_all,
                                                       ncores_all,
                                                       use_symmetry_all,
                                                       complex_all,
                                                       pick):
                # Merge right away: the merged dict is what gets run.
                yield atoms, {**params, **extra}, tag, xtag


def main(args: str | list[str] | None = None) -> int:
    """Command-line entry point for the fuzz tester.

    Parameters
    ----------
    args:
        Command-line arguments as a list, a single whitespace-separated
        string, or None to let argparse read ``sys.argv``.

    Returns
    -------
    int
        0 if all calculations sharing a tag agree (see check()), 1 on the
        first disagreement.  With ``--fuzz`` the function keeps drawing
        random combinations until a disagreement is found.

    Most options take comma-separated lists of values:
    ``--repeat 1x1x1,2x1x1``, ``--vacuum 0.0,4.0``, ``--pbc 0,1``,
    ``--mode pw,lcao``, ``--code new,old``, ``--ncores 1,2``,
    ``--use-symmetry 1,0``, ``--complex 0,1`` and ``--kpts 2.0,2x2x1``
    (a float is a k-point density, ``AxBxC`` an explicit grid).
    """
    if isinstance(args, str):
        args = args.split()

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--repeat')
    parser.add_argument('-p', '--pbc')
    parser.add_argument('-v', '--vacuum')
    parser.add_argument('-M', '--magmoms')
    parser.add_argument('-k', '--kpts')
    parser.add_argument('-m', '--mode')
    parser.add_argument('-c', '--code')
    parser.add_argument('-n', '--ncores')
    parser.add_argument('-s', '--use-symmetry')
    # NOTE: --spin-polarized is accepted but not used anywhere below.
    parser.add_argument('-S', '--spin-polarized')
    parser.add_argument('-x', '--complex')
    parser.add_argument('-i', '--ignore-cache', action='store_true')
    parser.add_argument('-o', '--stdout', action='store_true')
    parser.add_argument('--pickle')
    parser.add_argument('--all', action='store_true')
    parser.add_argument('--fuzz', action='store_true')
    parser.add_argument('system', nargs='*')
    opts = parser.parse_intermixed_args(args)

    if opts.pickle:
        # Worker mode: run() wrote the job to a pickle file and started
        # this module under mpiexec.  The file is produced by run() itself.
        pckl_file = Path(opts.pickle)
        atoms, params, result_file = pickle.loads(pckl_file.read_bytes())
        run2(atoms, params, result_file)
        return 0

    if opts.all or opts.fuzz:
        system_names = opts.system or list(systems)
        defaults = {'repeat': '1x1x1,2x1x1',
                    'vacuum': '0.0,4.0',
                    'pbc': '0,1',
                    'mode': 'pw,lcao,fd',
                    'code': 'new,old',
                    'ncores': '1,2,3,4',
                    'kpts': '2.0,3.0',
                    'use_symmetry': '1,0',
                    'complex': '0,1'}
    else:
        system_names = opts.system
        defaults = {'repeat': '1x1x1',
                    'vacuum': '0.0',
                    'pbc': '0',
                    'mode': 'pw',
                    'code': 'new',
                    'ncores': '1',
                    'kpts': '2.0',
                    'use_symmetry': '1',
                    'complex': '0'}
    for option, default in defaults.items():
        if not getattr(opts, option):
            setattr(opts, option, default)

    if world.size > 1:
        # Already running in parallel: only this core count can be used.
        opts.ncores = str(world.size)

    repeat_all = [[int(r) for r in rrr.split('x')]
                  for rrr in opts.repeat.split(',')]
    vacuum_all = [float(v) for v in opts.vacuum.split(',')]
    pbc_all = [bool(int(p)) for p in opts.pbc.split(',')]

    magmoms = None if opts.magmoms is None else [
        float(m) for m in opts.magmoms.split(',')]

    mode_all = opts.mode.split(',')
    # Entries are comma-separated, so a grid must use 'x' between its
    # numbers (same convention as --repeat and the '-k2x2x1' tag).
    kpts_all = [_parse_kpts(kpt) for kpt in opts.kpts.split(',')]

    code_all = opts.code.split(',')
    ncores_all = [int(c) for c in opts.ncores.split(',')]
    use_symmetry_all = [bool(int(s)) for s in opts.use_symmetry.split(',')]
    complex_all = [bool(int(s)) for s in opts.complex.split(',')]

    if opts.fuzz:
        def pick(choices):
            # One random value per parameter and pass.
            return [random.choice(choices)]
    else:
        def pick(choices):
            # Exhaustive: try every value.
            return choices

    # First result seen for each tag; later results are compared to it.
    calculations = {}

    while True:
        for atoms, params, tag, xtag in _iter_jobs(system_names,
                                                   repeat_all,
                                                   vacuum_all,
                                                   pbc_all,
                                                   magmoms,
                                                   mode_all,
                                                   kpts_all,
                                                   code_all,
                                                   ncores_all,
                                                   use_symmetry_all,
                                                   complex_all,
                                                   pick):
            result = run(atoms,
                         params,
                         tag + ' ' + xtag,
                         opts.ignore_cache,
                         opts.stdout)
            if not check(tag, result, calculations):
                return 1

        if not opts.fuzz:
            return 0
def run(atoms: Atoms,
        params: dict[str, Any],
        tag: str,
        ignore_cache: bool = False,
        use_stdout: bool = False) -> dict[str, Any]:
    """Run one calculation, or fetch its cached result, and print a summary.

    Parameters
    ----------
    atoms:
        The system to calculate.
    params:
        Calculator parameters plus the bookkeeping keys ``'code'`` and
        ``'ncores'``.  The dict is copied and never modified.
    tag:
        ``'<name> <options...>'``.  With the spaces removed it becomes
        the base name of the files written to the ``fuzz/`` folder.
    ignore_cache:
        Recompute even if ``fuzz/<tag>.json`` already exists.
    use_stdout:
        If false, a ``'txt'`` parameter pointing at ``fuzz/<tag>.txt``
        is added to the calculator parameters.

    Returns
    -------
    dict
        The result dictionary (``'energy'``, ``'time'``, ``'forces'``).

    Raises
    ------
    subprocess.CalledProcessError
        If the ``mpiexec`` child process exits with a non-zero status.
    """
    params = params.copy()
    name, things = tag.split(' ', 1)
    print(f'{name:3} {things}:', end='')
    tag = tag.replace(' ', '')
    folder = Path('fuzz')
    # exist_ok avoids the check-then-create race when several MPI ranks
    # execute this function at the same time.
    folder.mkdir(exist_ok=True)
    result_file = folder / f'{tag}.json'
    if not use_stdout:
        params['txt'] = str(result_file.with_suffix('.txt'))
    if not result_file.is_file() or ignore_cache:
        print(' ...', end='', flush=True)
        ncores = params.pop('ncores')
        if ncores == world.size:
            # Right number of cores already: run in this process.
            result = run2(atoms, params, result_file)
        else:
            # Hand the job to a child started under mpiexec; it reads the
            # pickle (see ``--pickle`` in main()) and writes the JSON file.
            pckl_file = result_file.with_suffix('.pckl')
            pckl_file.write_bytes(pickle.dumps((atoms, params, result_file)))
            args = ['mpiexec', '-np', str(ncores),
                    sys.executable, '-m', 'gpaw.test.fuzz',
                    '--pickle', str(pckl_file)]
            extra = os.environ.get('GPAW_MPI_OPTIONS')
            if extra:
                # Extra launcher options go right after 'mpiexec'.
                args[1:1] = extra.split()
            subprocess.run(args, check=True, env=os.environ)
            result, _ = json.loads(result_file.read_text())
            pckl_file.unlink()
    else:
        print('    ', end='')
        result, _ = json.loads(result_file.read_text())
    print(f' {result["energy"]:14.6f} eV, {result["time"]:9.3f} s')
    return result
def run2(atoms: Atoms,
         params: dict[str, Any],
         result_file: Path) -> dict[str, Any]:
    """Do the actual calculation in the current process and store the result.

    Parameters
    ----------
    atoms:
        System to calculate.  A calculator is attached for the duration
        of the call and removed again before returning.
    params:
        Calculator parameters plus ``'code'``; a value starting with
        ``'n'`` selects the new GPAW implementation, anything else the
        old one.  The dict is copied and never modified.
    result_file:
        JSON file to write (on rank 0 only).  A ``.gpw`` file with the
        same stem is also written and read back as a consistency check.

    Returns
    -------
    dict
        ``{'time': ..., 'energy': ..., 'forces': ...}`` where forces is a
        nested list, or None if the calculator raises NotImplementedError.
    """
    params = params.copy()

    code = params.pop('code')
    if code[0] == 'n':
        if params.pop('dtype', None) == complex:
            # The new code takes this flag inside the mode description.
            # create_parameters() supplies the mode as a plain string, so
            # turn it into a dict first; always work on a fresh dict so
            # the caller's (shallow-copied) parameters are not modified.
            mode = params['mode']
            mode = {'name': mode} if isinstance(mode, str) else dict(mode)
            mode['force_complex_dtype'] = True
            params['mode'] = mode
        calc = NewGPAW(**params)
    else:
        calc = OldGPAW(**params)
    atoms.calc = calc

    t1 = time()
    energy = atoms.get_potential_energy()
    try:
        forces = atoms.get_forces()
    except NotImplementedError:
        forces = None

    t2 = time()

    result = {'time': t2 - t1,
              'energy': energy,
              'forces': None if forces is None else forces.tolist()}

    # Round-trip through a gpw-file: what is read back must agree with
    # what was just calculated.
    gpw_file = result_file.with_suffix('.gpw')
    calc.write(gpw_file, mode='all')

    dft = NewGPAW(gpw_file).dft

    energy2 = dft.results['energy'] * Ha
    assert abs(energy2 - energy) < 1e-13, (energy2, energy)

    if forces is not None:
        forces2 = dft.results['forces'] * Ha / Bohr
        assert abs(forces2 - forces).max() < 1e-14

    # ibz_index = atoms.calc.wfs.kd.bz2ibz_k[p.kpt]
    # eigs = atoms.calc.get_eigenvalues(ibz_index, p.spin)

    if world.rank == 0:
        if 'dtype' in params:
            # The complex type object cannot be stored as JSON.
            params['dtype'] = 'complex'
        result_file.write_text(json.dumps([result, params], indent=2))

    atoms.calc = None

    return result
def check(tag: str,
          result: dict[str, Any],
          calculations: dict[str, dict[str, Any]],
          *,
          energy_tol: float = 0.0005,
          force_tol: float = 0.001) -> bool:
    """Compare *result* with the first result recorded for *tag*.

    Parameters
    ----------
    tag:
        Key identifying calculations that are expected to agree.
    result:
        Dictionary with ``'energy'`` and ``'forces'`` (nested list or
        None).
    calculations:
        Reference results by tag.  Updated in place: an unseen tag
        stores *result* as the reference, and a reference without forces
        adopts the forces of the first later result that has them.
    energy_tol, force_tol:
        Largest accepted absolute deviation of the energy and of any
        force component, in the units of the stored values.

    Returns
    -------
    bool
        False if energy or forces deviate by more than the tolerance,
        or if the deviation is NaN.  True otherwise.
    """
    if tag not in calculations:
        calculations[tag] = result
        return True

    reference = calculations[tag]
    e0 = reference['energy']
    e = result['energy']
    error = e - e0
    # "not <=" rather than ">" so that a NaN deviation counts as a failure.
    if not abs(error) <= energy_tol:
        print('Energy error:', e, e0, error)
        return False

    f0 = reference['forces']
    f = result['forces']
    if f is None:
        # Nothing to compare.
        return True
    if f0 is None:
        # First result with forces for this tag becomes the reference.
        reference['forces'] = f
        return True

    error = abs(np.array(f) - f0).max()
    if not error <= force_tol:
        print('Force error:', error)
        return False
    return True
def create_systems(system_names: list[str],
                   repeats: list[list[int]],
                   vacuums: list[float],
                   pbcs: list[bool],
                   magmoms: list[float] | None,
                   pick: PickFunc) -> tuple[Atoms, str]:
    """Generate ``(atoms, tag)`` pairs for the requested system variations.

    For each registered system name, the values selected by *pick* from
    *repeats*, *vacuums* and *pbcs* are combined.  Combinations that make
    no sense are skipped: repeating along a non-periodic axis, adding
    vacuum to a fully periodic system, and switching on periodicity for
    a system that is already fully periodic.  If *magmoms* is given, the
    list is tiled over the atoms and set as initial magnetic moments.
    """
    for name in pick(system_names):
        base = systems[name]
        for repeat in pick(repeats):
            # Repetition is only meaningful along periodic axes.
            if any(r > 1 and not periodic
                   for periodic, r in zip(base.pbc, repeat)):
                continue
            repeated = base.repeat(repeat)
            for vacuum in pick(vacuums):
                if not vacuum:
                    padded = repeated
                else:
                    open_axes = [axis
                                 for axis, periodic in enumerate(base.pbc)
                                 if not periodic]
                    if not open_axes:
                        # No non-periodic axis to add vacuum along.
                        continue
                    padded = repeated.copy()
                    padded.center(vacuum=vacuum, axis=open_axes)
                for pbc in pick(pbcs):
                    if not pbc:
                        final = padded
                    elif padded.pbc.all():
                        continue
                    else:
                        final = padded.copy()
                        final.pbc = pbc

                    if magmoms is not None:
                        ncopies = len(final) // len(magmoms)
                        final.set_initial_magnetic_moments(magmoms * ncopies)

                    tag = (f'{name} '
                           f'-r{"x".join(str(r) for r in repeat)} '
                           f'-v{vacuum:.1f} '
                           f'-p{int(pbc)}')
                    yield final, tag
def create_parameters(modes: list[str],
                      kpts_all: list[float | list[int]],
                      pick: PickFunc) -> tuple[dict[str, Any], str]:
    """Generate ``(parameters, tag)`` pairs for mode/k-point combinations.

    A float k-point entry is passed on as ``{'density': value}``; any
    other entry is used as given.  The eigensolver is ``'davidson'`` for
    mode ``'pw'`` and None for every other mode.
    """
    for mode in pick(modes):
        eigensolver = 'davidson' if mode == 'pw' else None
        for kpt in pick(kpts_all):
            if not isinstance(kpt, float):
                kpts = kpt
                ktag = f'-k{"x".join(str(k) for k in kpt)}'
            else:
                kpts = {'density': kpt}
                ktag = f'-k{kpt:.1f}'
            parameters = {'eigensolver': eigensolver,
                          'mode': mode,
                          'kpts': kpts}
            yield parameters, f'-m{mode} {ktag}'
def create_extra_parameters(codes: list[str],
                            ncores_all: list[int],
                            symmetry_all: list[bool],
                            complex_all: list[bool],
                            pick: PickFunc) -> dict[str, Any]:
    """Generate ``(parameters, tag)`` pairs for the technical variations.

    NOTE: this is a generator; each item is a ``(dict, str)`` tuple even
    though the return annotation names only the dict.

    Every yielded dict holds ``'code'`` and ``'ncores'``, plus
    ``'symmetry': 'off'`` when symmetry is disabled and
    ``'dtype': complex`` when a complex dtype is requested.  A new dict
    is built for each combination, so a setting from one combination
    (in particular ``'dtype'``) cannot leak into later ones.
    """
    for code in pick(codes):
        for ncores in pick(ncores_all):
            for use_symm in pick(symmetry_all):
                for force_complex_dtype in pick(complex_all):
                    extra = {'code': code,
                             'ncores': ncores}
                    if not use_symm:
                        extra['symmetry'] = 'off'
                    if force_complex_dtype:
                        extra['dtype'] = complex
                    yield (extra,
                           (f'-c{code} -n{ncores} -s{int(use_symm)} '
                            f'-x{int(force_complex_dtype)}'))
# Registry of test systems: function name -> object built by that function.
systems = {}


def system(func):
    """Decorator: call *func* once and register what it returns.

    The return value is stored in ``systems`` under the function's name;
    the function itself is returned unchanged.
    """
    built = func()
    systems[func.__name__] = built
    return func
@system
def h():
    """Single H atom with initial magnetic moment 1, centred with 2.0 vacuum."""
    hydrogen = Atoms('H', magmoms=[1])
    hydrogen.center(vacuum=2.0)
    return hydrogen
@system
def h2():
    """Two H atoms 0.75 apart along y, centred with 2.0 vacuum."""
    positions = [(0, 0, 0), (0, 0.75, 0)]
    molecule = Atoms('H2', positions)
    molecule.center(vacuum=2.0)
    return molecule
@system
def si():
    """Bulk Si built with lattice constant a=5.4."""
    return bulk('Si', a=5.4)
@system
def fe():
    """Bulk Fe with an initial magnetic moment of 2.3 on its atom."""
    iron = bulk('Fe')
    iron.set_initial_magnetic_moments([2.3])
    return iron
@system
def li():
    """Li atom centred in a 5.0 x 5.0 x 1.5 cell, periodic along the third axis only."""
    side = 5.0
    wire = Atoms('Li', cell=[side, side, 1.5], pbc=(0, 0, 1))
    wire.center()
    return wire
if __name__ == '__main__':
    # Use main()'s return value (0 = success, 1 = failure) as exit status.
    sys.exit(main())