Coverage for gpaw/test/fuzz.py: 80%

269 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

from __future__ import annotations

import argparse
import json
import os
import pickle
import random
import subprocess
import sys
from collections.abc import Iterator
from pathlib import Path
from time import time
from typing import Any, TypeVar, Callable, TYPE_CHECKING

import numpy as np
from ase import Atoms
from ase.build import bulk
from ase.units import Bohr, Ha
from gpaw.calculator import GPAW as OldGPAW
from gpaw.mpi import world
from gpaw.new.ase_interface import GPAW as NewGPAW

21 

# Names needed only by static type checkers.  This module uses
# ``from __future__ import annotations``, so annotations are never evaluated
# at runtime and these aliases do not have to exist then.
if TYPE_CHECKING:
    T = TypeVar('T')
    # A "picker" receives the list of candidate values for one option and
    # returns the values to iterate over (in main(): either all of them, or
    # a single randomly chosen one in --fuzz mode).
    PickFunc = Callable[[list[T]], list[T]]

25 

26 

def main(args: str | list[str] | None = None) -> int:
    """Run GPAW calculations over a grid of systems/parameters and compare.

    Each combination of system, repeat, vacuum, boundary conditions, mode
    and k-points forms a "tag".  Every tag is computed with each picked
    code, core count, symmetry setting and dtype, and the results for one
    tag must agree with each other (see check()).

    Multi-valued options are comma-separated lists.  Repeats and explicit
    k-point meshes use 'x' as separator (``-r 2x1x1``, ``-k 4x4x1``); a
    plain number given to ``-k`` is a k-point density.

    ``--all`` (or ``--fuzz``) fills unspecified options with a broad set of
    values and, if no systems are named, uses all registered systems.
    ``--fuzz`` also picks one random value per option and keeps going until
    a mismatch is found.  ``--pickle FILE`` is the worker mode that run()
    uses to execute a single calculation under ``mpiexec``.

    Returns 0 if all calculations agreed, 1 if a mismatch was found.
    """
    if isinstance(args, str):
        args = args.split()

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--repeat')
    parser.add_argument('-p', '--pbc')
    parser.add_argument('-v', '--vacuum')
    parser.add_argument('-M', '--magmoms')
    parser.add_argument('-k', '--kpts')
    parser.add_argument('-m', '--mode')
    parser.add_argument('-c', '--code')
    parser.add_argument('-n', '--ncores')
    parser.add_argument('-s', '--use-symmetry')
    parser.add_argument('-S', '--spin-polarized')
    parser.add_argument('-x', '--complex')
    parser.add_argument('-i', '--ignore-cache', action='store_true')
    parser.add_argument('-o', '--stdout', action='store_true')
    parser.add_argument('--pickle')
    parser.add_argument('--all', action='store_true')
    parser.add_argument('--fuzz', action='store_true')
    parser.add_argument('system', nargs='*')
    args = parser.parse_intermixed_args(args)

    if args.pickle:
        # Worker mode: run one pickled calculation (launched by run()).
        pckl_file = Path(args.pickle)
        atoms, params, result_file = pickle.loads(pckl_file.read_bytes())
        run2(atoms, params, result_file)
        return 0

    if args.all or args.fuzz:
        system_names = args.system or list(systems)
        args.repeat = args.repeat or '1x1x1,2x1x1'
        args.vacuum = args.vacuum or '0.0,4.0'
        args.pbc = args.pbc or '0,1'
        args.mode = args.mode or 'pw,lcao,fd'
        args.code = args.code or 'new,old'
        args.ncores = args.ncores or '1,2,3,4'
        args.kpts = args.kpts or '2.0,3.0'
        args.use_symmetry = args.use_symmetry or '1,0'
        args.complex = args.complex or '0,1'
    else:
        system_names = args.system
        args.repeat = args.repeat or '1x1x1'
        args.vacuum = args.vacuum or '0.0'
        args.pbc = args.pbc or '0'
        args.mode = args.mode or 'pw'
        args.code = args.code or 'new'
        args.ncores = args.ncores or '1'
        args.kpts = args.kpts or '2.0'
        args.use_symmetry = args.use_symmetry or '1'
        args.complex = args.complex or '0'

    if world.size > 1:
        # Already running in parallel: use exactly the available cores.
        args.ncores = str(world.size)

    repeat_all = [[int(r) for r in rrr.split('x')]
                  for rrr in args.repeat.split(',')]
    vacuum_all = [float(v) for v in args.vacuum.split(',')]
    pbc_all = [bool(int(p)) for p in args.pbc.split(',')]

    magmoms = None if args.magmoms is None else [
        float(m) for m in args.magmoms.split(',')]

    mode_all = args.mode.split(',')
    # Alternatives are comma-separated, so an explicit k-point mesh has to
    # use 'x' (e.g. '4x4x1', matching the '-k4x4x1' tag written by
    # create_parameters()); any other value is a k-point density.
    kpts_all = [[int(k) for k in kpt.split('x')] if 'x' in kpt
                else float(kpt)
                for kpt in args.kpts.split(',')]

    code_all = args.code.split(',')
    ncores_all = [int(c) for c in args.ncores.split(',')]
    use_symmetry_all = [bool(int(s)) for s in args.use_symmetry.split(',')]
    complex_all = [bool(int(s)) for s in args.complex.split(',')]

    # TODO: --spin-polarized is parsed but not used yet.

    if args.fuzz:
        def pick(choices):
            return [random.choice(choices)]
    else:
        def pick(choices):
            return choices

    # First result computed for each tag; later results must match it.
    calculations: dict[str, dict[str, Any]] = {}

    while True:
        for atoms, atag in create_systems(system_names,
                                          repeat_all,
                                          vacuum_all,
                                          pbc_all,
                                          magmoms,
                                          pick):
            for params, ptag in create_parameters(mode_all,
                                                  kpts_all,
                                                  pick):
                # All calculations sharing this tag must agree.
                tag = atag + ' ' + ptag
                for extra, xtag in create_extra_parameters(code_all,
                                                           ncores_all,
                                                           use_symmetry_all,
                                                           complex_all,
                                                           pick):
                    result = run(atoms,
                                 {**params, **extra},
                                 tag + ' ' + xtag,
                                 args.ignore_cache,
                                 args.stdout)
                    if not check(tag, result, calculations):
                        return 1
        if not args.fuzz:
            return 0

152 

153 

def run(atoms: Atoms,
        params: dict[str, Any],
        tag: str,
        ignore_cache: bool = False,
        use_stdout: bool = False) -> dict[str, Any]:
    """Run one calculation, or load its cached result, and print a summary.

    The first word of ``tag`` and the rest of it are printed as a line
    prefix.  The tag with spaces removed names the files in the ``fuzz/``
    folder: ``<tag>.json`` for the result and, unless ``use_stdout`` is
    set, ``<tag>.txt`` for the GPAW log.

    An existing JSON result is reused unless ``ignore_cache`` is true.
    Otherwise the calculation runs in this process when
    ``params['ncores']`` equals the MPI world size, or in a child
    ``mpiexec`` process otherwise.  Extra launcher options can be given in
    ``$GPAW_MPI_OPTIONS``.

    Returns the result dict with the keys 'time', 'energy' and 'forces'.
    """
    params = params.copy()
    name, things = tag.split(' ', 1)
    print(f'{name:3} {things}:', end='')
    tag = tag.replace(' ', '')
    folder = Path('fuzz')
    # exist_ok=True: under MPI every rank runs this line, so a separate
    # is_dir() check followed by mkdir() could race and raise
    # FileExistsError.
    folder.mkdir(exist_ok=True)
    result_file = folder / f'{tag}.json'
    if not use_stdout:
        params['txt'] = str(result_file.with_suffix('.txt'))
    if not result_file.is_file() or ignore_cache:
        print(' ...', end='', flush=True)
        ncores = params.pop('ncores')
        if ncores == world.size:
            result = run2(atoms, params, result_file)
        else:
            # Hand the job to a child MPI run of this module in --pickle
            # worker mode.
            pckl_file = result_file.with_suffix('.pckl')
            pckl_file.write_bytes(pickle.dumps((atoms, params, result_file)))
            args = ['mpiexec', '-np', str(ncores),
                    sys.executable, '-m', 'gpaw.test.fuzz',
                    '--pickle', str(pckl_file)]
            extra = os.environ.get('GPAW_MPI_OPTIONS')
            if extra:
                args[1:1] = extra.split()
            subprocess.run(args, check=True, env=os.environ)
            result, _ = json.loads(result_file.read_text())
            # The pickle is removed only after a successful run; on failure
            # it is left in place (and can be rerun with --pickle).
            pckl_file.unlink()
    else:
        print(' ', end='')
        result, _ = json.loads(result_file.read_text())
    print(f' {result["energy"]:14.6f} eV, {result["time"]:9.3f} s')
    return result

191 

192 

def run2(atoms: Atoms,
         params: dict[str, Any],
         result_file: Path) -> dict[str, Any]:
    """Run one calculation in this process and store its result.

    ``params['code']`` selects the calculator: values starting with 'n'
    use the new GPAW code, anything else the old one.  Energy and forces
    are computed and timed; forces are None if the calculator raises
    NotImplementedError.

    The calculation is then written to a ``.gpw`` file next to
    ``result_file`` and read back with the new code.  The stored energy and
    forces must reproduce the computed values.  Rank 0 writes
    ``[result, params]`` as JSON to ``result_file``.

    Returns a dict with the keys 'time', 'energy' and 'forces'.
    """
    params = params.copy()

    code = params.pop('code')
    if code[0] == 'n':
        if params.pop('dtype', None) == complex:
            # The new code takes this flag inside the mode specification.
            # ``mode`` is normally a plain name such as 'pw', and a str
            # cannot be item-assigned, so normalize it to a dict.  A dict
            # mode is copied so that the caller's params are not mutated
            # through the shallow params.copy() above.
            mode = params['mode']
            mode = {'name': mode} if isinstance(mode, str) else dict(mode)
            mode['force_complex_dtype'] = True
            params['mode'] = mode
        calc = NewGPAW(**params)
    else:
        calc = OldGPAW(**params)
    atoms.calc = calc

    t1 = time()
    energy = atoms.get_potential_energy()
    try:
        forces = atoms.get_forces()
    except NotImplementedError:
        forces = None

    t2 = time()

    result = {'time': t2 - t1,
              'energy': energy,
              'forces': None if forces is None else forces.tolist()}

    # Round trip through a .gpw file: the stored results must reproduce
    # what was just computed.
    gpw_file = result_file.with_suffix('.gpw')
    calc.write(gpw_file, mode='all')

    dft = NewGPAW(gpw_file).dft

    energy2 = dft.results['energy'] * Ha
    assert abs(energy2 - energy) < 1e-13, (energy2, energy)

    if forces is not None:
        forces2 = dft.results['forces'] * Ha / Bohr
        assert abs(forces2 - forces).max() < 1e-14

    # TODO: also compare eigenvalues, e.g.:
    # ibz_index = atoms.calc.wfs.kd.bz2ibz_k[p.kpt]
    # eigs = atoms.calc.get_eigenvalues(ibz_index, p.spin)

    if world.rank == 0:
        # The dtype value (the ``complex`` type; popped above for the new
        # code) is not JSON serializable, so it is stored by name.
        if 'dtype' in params:
            params['dtype'] = 'complex'
        result_file.write_text(json.dumps([result, params], indent=2))

    atoms.calc = None

    return result

243 

244 

def check(tag: str,
          result: dict[str, Any],
          calculations: dict[str, dict[str, Any]]) -> bool:
    """Compare ``result`` with the reference stored for ``tag``.

    The first result seen for a tag becomes the reference, is stored in
    ``calculations``, and counts as a pass.  Later results pass when the
    energy differs by at most 0.0005 and, if both sides have forces, the
    largest force deviation is at most 0.001.  If the reference has no
    forces but ``result`` does, those forces are adopted into the
    reference.  Mismatches are printed and reported by returning False.
    """
    if tag not in calculations:
        calculations[tag] = result
        return True

    reference = calculations[tag]
    ref_energy = reference['energy']
    ref_forces = reference['forces']
    new_energy = result['energy']
    new_forces = result['forces']

    delta = new_energy - ref_energy
    if abs(delta) > 0.0005:
        print('Energy error:', new_energy, ref_energy, delta)
        return False

    if ref_forces is None:
        # Nothing to compare against yet: remember these forces, if any.
        if new_forces is not None:
            reference['forces'] = new_forces
        return True

    if new_forces is None:
        return True

    max_deviation = abs(np.array(new_forces) - ref_forces).max()
    if max_deviation > 0.001:
        print('Force error:', max_deviation)
        return False
    return True

271 

272 

def create_systems(system_names: list[str],
                   repeats: list[list[int]],
                   vacuums: list[float],
                   pbcs: list[bool],
                   magmoms: list[float] | None,
                   pick: PickFunc) -> Iterator[tuple[Atoms, str]]:
    """Yield ``(atoms, tag)`` for each picked variant of the named systems.

    For every picked name (a key of the module-level ``systems`` registry)
    the picked repeats, vacuums and boundary conditions are combined:

    * repeats larger than 1 along a non-periodic axis are skipped;
    * a non-zero vacuum is added along the non-periodic axes only, and is
      skipped for fully periodic systems;
    * ``pbc=True`` makes all axes periodic and is skipped when the system
      is already fully periodic; ``pbc=False`` keeps the system's own
      boundary conditions.

    If ``magmoms`` is given, it is tiled over the atoms as initial magnetic
    moments.  The tag encodes the choices, e.g. ``'si -r1x1x1 -v0.0 -p0'``.
    """
    for name in pick(system_names):
        atoms = systems[name]
        for repeat in pick(repeats):
            # Repeating along a non-periodic axis is not meaningful.
            if any(not p and r > 1 for p, r in zip(atoms.pbc, repeat)):
                continue
            ratoms = atoms.repeat(repeat)
            for vacuum in pick(vacuums):
                if vacuum:
                    vatoms = ratoms.copy()
                    axes = [a for a, p in enumerate(atoms.pbc) if not p]
                    if axes:
                        vatoms.center(vacuum=vacuum, axis=axes)
                    else:
                        # Fully periodic: there is no axis to add vacuum to.
                        continue
                else:
                    vatoms = ratoms
                for pbc in pick(pbcs):
                    if pbc:
                        if vatoms.pbc.all():
                            continue  # already fully periodic
                        patoms = vatoms.copy()
                        patoms.pbc = pbc
                    else:
                        patoms = vatoms  # keep the system's own pbc

                    if magmoms is not None:
                        # Tile the given moments over all atoms; this
                        # assumes len(patoms) is a multiple of len(magmoms).
                        patoms.set_initial_magnetic_moments(
                            magmoms * (len(patoms) // len(magmoms)))

                    tag = (f'{name} '
                           f'-r{"x".join(str(r) for r in repeat)} '
                           f'-v{vacuum:.1f} '
                           f'-p{int(pbc)}')
                    yield patoms, tag

313 

314 

def create_parameters(modes: list[str],
                      kpts_all: list[float | list[int]],
                      pick: PickFunc) -> Iterator[tuple[dict[str, Any], str]]:
    """Yield ``(params, tag)`` for each picked mode and k-point setting.

    A number in ``kpts_all`` is used as a k-point density
    (``{'density': value}``); a list is used as an explicit k-point mesh.
    The Davidson eigensolver is requested for 'pw' mode; other modes get
    ``eigensolver=None``.  Tags look like ``'-mpw -k2.0'`` or
    ``'-mlcao -k2x2x2'``.
    """
    for mode in pick(modes):
        for kpt in pick(kpts_all):
            # Integers are accepted as densities as well as floats.
            # Previously an int fell through to the mesh branch and crashed
            # when iterated.
            if isinstance(kpt, (int, float)):
                kpts = {'density': kpt}
                ktag = f'-k{kpt:.1f}'
            else:
                kpts = kpt
                ktag = f'-k{"x".join(str(k) for k in kpt)}'
            yield {'eigensolver': 'davidson' if mode == 'pw' else None,
                   'mode': mode,
                   'kpts': kpts}, f'-m{mode} {ktag}'

329 

330 

def create_extra_parameters(
        codes: list[str],
        ncores_all: list[int],
        symmetry_all: list[bool],
        complex_all: list[bool],
        pick: PickFunc) -> Iterator[tuple[dict[str, Any], str]]:
    """Yield ``(extra_params, tag)`` for each picked run configuration.

    ``extra_params`` always contains 'code' and 'ncores'.  It adds
    ``symmetry='off'`` when symmetry is not used and ``dtype=complex``
    when a complex dtype is forced.  Tags look like
    ``'-cnew -n1 -s1 -x0'``.

    Every yielded dict is a new object.  Previously one dict was shared
    and mutated, so a 'dtype' set for one combination leaked into the
    combinations that followed it.
    """
    for code in pick(codes):
        for ncores in pick(ncores_all):
            for use_symm in pick(symmetry_all):
                for force_complex_dtype in pick(complex_all):
                    params: dict[str, Any] = {'code': code,
                                              'ncores': ncores}
                    if not use_symm:
                        params['symmetry'] = 'off'
                    if force_complex_dtype:
                        params['dtype'] = complex
                    yield (params,
                           (f'-c{code} -n{ncores} -s{int(use_symm)} '
                            f'-x{int(force_complex_dtype)}'))

351 

352 

# Registry of test systems, filled by the @system decorator below:
# builder-function name -> object returned by calling that builder.
systems = {}


def system(func):
    """Decorator: call ``func`` once and register its result in ``systems``.

    The result is stored under ``func.__name__``; ``func`` itself is
    returned unchanged.
    """
    built = func()
    systems[func.__name__] = built
    return func

359 

360 

@system
def h():
    """Single H atom (initial magnetic moment 1), centered with vacuum 2.0."""
    hydrogen = Atoms('H', magmoms=[1])
    hydrogen.center(vacuum=2.0)
    return hydrogen


@system
def h2():
    """H2 molecule with atoms 0.75 apart along y, centered with vacuum 2.0."""
    positions = [(0, 0, 0), (0, 0.75, 0)]
    molecule = Atoms('H2', positions)
    molecule.center(vacuum=2.0)
    return molecule


@system
def si():
    """Bulk silicon from ase.build.bulk with lattice constant 5.4."""
    return bulk('Si', a=5.4)


@system
def fe():
    """Bulk iron from ase.build.bulk with initial magnetic moment 2.3."""
    iron = bulk('Fe')
    iron.set_initial_magnetic_moments([2.3])
    return iron


@system
def li():
    """One Li atom in a 5.0 x 5.0 x 1.5 cell, periodic along z only."""
    side = 5.0
    chain = Atoms('Li', cell=[side, side, 1.5], pbc=(0, 0, 1))
    chain.center()
    return chain

394 

395 

if __name__ == '__main__':
    # main() returns 0 on success and 1 on a mismatch; use that as the
    # process exit status.
    exit_code = main()
    raise SystemExit(exit_code)