Coverage for gpaw/test/response/test_chiks.py: 97%

230 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1"""Test functionality to compute the four-component susceptibility tensor for 

2the Kohn-Sham system.""" 

3 

4from itertools import product, combinations 

5 

6import numpy as np 

7import pytest 

8from gpaw import GPAW 

9from gpaw.mpi import world 

10from gpaw.response import ResponseContext, ResponseGroundStateAdapter 

11from gpaw.response.frequencies import (ComplexFrequencyDescriptor, 

12 FrequencyDescriptor) 

13from gpaw.response.chiks import ChiKSCalculator, SelfEnhancementCalculator 

14from gpaw.response.chi0 import Chi0Calculator 

15from gpaw.response.pair_functions import (get_inverted_pw_mapping, 

16 get_pw_coordinates) 

17from gpaw.test.gpwfile import response_band_cutoff 

18 

19# ---------- chiks parametrization ---------- # 

20 

21 

def generate_system_s(spincomponents=('00', '+-')):
    """Return the (wfs, spincomponent) systems to compute chiks for.

    Parameters
    ----------
    spincomponents : collection of str
        Only systems whose spin component is contained in this collection are
        returned. Defaults to both '00' and '+-'. An immutable tuple is used
        as default to avoid the shared-mutable-default pitfall.

    Returns
    -------
    list of tuple
        (wfs, spincomponent) pairs, where wfs is used as a key into the
        gpw_files fixture and response_band_cutoff.
    """
    # Compute chiks for different materials and spin components
    system_s = [  # wfs, spincomponent
        ('fancy_si_pw', '00'),
        ('al_pw', '00'),
        ('fe_pw', '00'),
        ('fe_pw', '+-'),
        ('co_pw', '00'),
        ('co_pw', '+-'),
    ]

    # Filter spincomponents
    return [system for system in system_s if system[1] in spincomponents]

37 

38 

def generate_qrel_q():
    """Return the fractional q-vector positions used in the tests.

    The positions lie on a path from Gamma towards a reciprocal lattice
    vector (see get_q_c for the direction used for each material).
    """
    return np.array((0.0, 0.25, 0.5))

44 

45 

def get_q_c(wfs, qrel):
    """Return the q-vector (relative coordinates) for a test system.

    The q-vector is qrel times a fixed direction on a high-symmetry path of
    the material: G-X for Si and Al, G-N for Fe and G-M for Co.

    Raises
    ------
    ValueError
        If wfs is not one of the known test systems.
    """
    if wfs in ['fancy_si_pw', 'al_pw']:
        direction_c = np.array([1., 0., 1.])  # G-X path
    elif wfs == 'fe_pw':
        direction_c = np.array([0., 0., 1.])  # G-N path
    elif wfs == 'co_pw':
        direction_c = np.array([1., 0., 0.])  # G-M path
    else:
        raise ValueError('Invalid wfs', wfs)

    return qrel * direction_c

60 

61 

def get_tolerances(system, qrel):
    """Return the absolute and relative tolerances for a test system.

    Parameters
    ----------
    system : tuple of str
        (wfs, spincomponent) pair, as generated by generate_system_s().
    qrel : float
        Fractional q-vector position. Tolerances are only defined for the
        tuned values 0.0, 0.25 and 0.5 (those of generate_qrel_q()).

    Returns
    -------
    atol, rtol : float
        Absolute and relative tolerance. The relative tolerance is the same
        for all systems.

    Raises
    ------
    ValueError
        If no tolerance has been defined for the given system and qrel.
    """
    wfs, spincomponent = system
    identifier = wfs + '_' + spincomponent

    # The q-vector positions for which the tolerances have been tuned
    known_qrels = (0.0, 0.25, 0.5)

    # For Si and Fe, the density-density response has perfect symmetry
    qrel_independent_atols = {
        'fancy_si_pw_00': 1e-8,
        'fe_pw_00': 1e-8,
    }

    # For the rest, we need to adjust the absolute tolerances for each
    # q-point (ordered as known_qrels). In general, it should be possible
    # to lower these tolerances when increasing the number of bands.
    qrel_dependent_atols = {
        # For Al, the symmetries are not perfectly conserved, but worst for
        # the q-point q_X
        'al_pw_00': (1e-6, 5e-5, 2e-4),
        # For Fe, the symmetries are not perfectly conserved for the
        # transverse magnetic response
        'fe_pw_+-': (3e-3, 16e-3, 5e-4),
        # For the density-density response in Co, the symmetries are not
        # perfectly conserved for any of the q-points, but quite well
        # conserved for q = 0
        'co_pw_00': (5e-5, 5e-3, 1e-3),
        # For the transverse magnetic response in Co, the symmetries are not
        # perfectly conserved for any of the q-points, but again quite well
        # conserved for q = 0
        'co_pw_+-': (5e-4, 1e-3, 1e-3),
    }

    # Reject untuned q-points explicitly, rather than failing on an unbound
    # tolerance further down
    if qrel not in known_qrels:
        raise ValueError(system, qrel)

    if identifier in qrel_independent_atols:
        atol = qrel_independent_atols[identifier]
    elif identifier in qrel_dependent_atols:
        atol = qrel_dependent_atols[identifier][known_qrels.index(qrel)]
    else:
        raise ValueError(system, qrel)

    rtol = 1e-5

    return atol, rtol

125 

126 

def generate_gc_g():
    """Return the pw-grid centering options to test.

    chiks is computed both on a gamma-centered (True) and on a q-centered
    (False) plane-wave grid.
    """
    return [True, False]

132 

133 

def generate_nblocks_n():
    """Return the block counts to test, given the size of `world`.

    One block is always included; 2 and 4 blocks are included only when they
    divide world.size evenly.
    """
    return [nblocks for nblocks in (1, 2, 4) if world.size % nblocks == 0]

142 

143 

144# ---------- Actual tests ---------- # 

145 

146 

@pytest.mark.response
@pytest.mark.kspair
@pytest.mark.parametrize(
    'system,qrel,gammacentered',
    product(generate_system_s(), generate_qrel_q(), generate_gc_g()))
def test_chiks(in_tmp_dir, gpw_files, system, qrel, gammacentered):
    r"""Test the internals of the ChiKSCalculator.

    The calculated susceptibility must not depend on internal details of the
    calculator. These details include:

    * the block distribution (nblocks),
    * the band summation scheme,
    * the internal data distribution,
    * the use of symmetries to reduce the k-point integral, and
    * basing the ground state adapter on a dynamic (and distributed) GPAW
      calculator.

    Furthermore, the reciprocity and inversion symmetry of the calculated
    susceptibilities are tested.
    """
    wfs, spincomponent = system
    atol, rtol = get_tolerances(system, qrel)
    q_c = get_q_c(wfs, qrel)
    ecut = 50

    # Vanishing and finite real frequencies, each combined with two imaginary
    # parts. The tiny one (1e-6) avoids risky floating point operations that
    # may cause NaNs or divide-by-zero; the other (0.1) is finite.
    omega_w = np.array([0., 0.05, 0.1, 0.2])
    zd = ComplexFrequencyDescriptor.from_array(
        list(omega_w + 1e-6j) + list(omega_w + 0.1j))

    # Calculation parameters to toggle. None of them should change the result
    disable_syms_s = [True, False]
    nblocks_n = generate_nblocks_n()
    bandsummation_b = ['double', 'pairwise']
    distribution_d = ['GZg', 'ZgG']

    # Tolerances depending on the symmetries of the system, and symmetry
    # independent tolerances (relating to the chiks distribution)
    sym_tols = dict(atol=atol, rtol=rtol)
    dist_tols = dict(atol=1e-8, rtol=1e-6)

    # Set up the factory based on a (dynamic) GPAW calculator
    calc = GPAW(gpw_files[wfs], parallel=dict(domain=1))
    factory = ChiKSTestingFactory(calc, spincomponent, q_c, zd,
                                  response_band_cutoff[wfs], ecut,
                                  gammacentered)

    # Symmetry toggle, cross-tabulated with nblocks and bandsummation
    factory.check_parameter_self_consistency(
        'disable_syms', disable_syms_s,
        cross_tabulation=dict(nblocks=nblocks_n,
                              bandsummation=bandsummation_b),
        **sym_tols)

    # Each pair of nblocks, cross-tabulated with disable_syms/bandsummation
    for nblocks_pair in combinations(nblocks_n, 2):
        factory.check_parameter_self_consistency(
            'nblocks', list(nblocks_pair),
            cross_tabulation=dict(disable_syms=disable_syms_s,
                                  bandsummation=bandsummation_b),
            **dist_tols)

    # Band summation scheme, cross-tabulated with disable_syms and nblocks
    factory.check_parameter_self_consistency(
        'bandsummation', bandsummation_b,
        cross_tabulation=dict(disable_syms=disable_syms_s,
                              nblocks=nblocks_n),
        **sym_tols)

    # Internal distribution, cross-tabulated with nblocks
    factory.check_parameter_self_consistency(
        'distribution', distribution_d,
        cross_tabulation=dict(nblocks=nblocks_n),
        **dist_tols)

    # Reciprocity and inversion symmetry: first cross-tabulating
    # disable_syms, nblocks and bandsummation, then distribution and nblocks
    for cross_tabulation in (dict(disable_syms=disable_syms_s,
                                  nblocks=nblocks_n,
                                  bandsummation=bandsummation_b),
                             dict(distribution=distribution_d,
                                  nblocks=nblocks_n)):
        factory.check_reciprocity_and_inversion_symmetry(
            cross_tabulation=cross_tabulation, **sym_tols)

    # Make it possible to check timings for the test
    factory.context.write_timer()

251 

252 

@pytest.mark.response
@pytest.mark.kspair
@pytest.mark.parametrize(
    'system,qrel',
    product(generate_system_s(spincomponents=['00']), generate_qrel_q()))
def test_chiks_vs_chi0(in_tmp_dir, gpw_files, system, qrel):
    """Test that the ChiKSCalculator is able to reproduce the Chi0Body.

    Only the default calculation parameters of the ChiKSCalculator are used.
    Cross-validation of the parameters is left to test_chiks."""
    wfs, spincomponent = system
    q_c = get_q_c(wfs, qrel)
    ecut = 50
    # Vanishing and finite real frequencies with a finite broadening eta
    eta = 0.15
    omega_w = np.array([0., 0.05, 0.1, 0.2])

    gs = ResponseGroundStateAdapter.from_gpw_file(gpw_files[wfs])
    nbands = response_band_cutoff[wfs]

    # chi0 takes real frequencies plus eta, chiks takes complex frequencies
    wd = FrequencyDescriptor.from_array_or_dict(omega_w)
    zd = ComplexFrequencyDescriptor.from_array(omega_w + 1.j * eta)

    # chiks with default parameters, using a global frequency distribution
    chiks_calc = ChiKSCalculator(gs, ecut=ecut, nbands=nbands)
    chiks = chiks_calc.calculate(spincomponent, q_c, zd)
    chiks = chiks.copy_with_global_frequency_distribution()
    chiks_calc.context.write_timer()

    # chi0 body without Hilbert transform and intraband contribution
    chi0_calc = Chi0Calculator(gs, wd=wd, eta=eta,
                               ecut=ecut, nbands=nbands,
                               hilbert=False, intraband=False)
    chi0_body = chi0_calc.calculate(q_c).body
    chi0_wGG = chi0_body.get_distributed_frequencies_array()
    chi0_calc.context.write_timer()

    assert chiks.array == pytest.approx(chi0_wGG, rel=1e-3, abs=1e-5)

308 

309 

@pytest.mark.response
@pytest.mark.kspair
@pytest.mark.parametrize(
    'system,qrel,gammacentered',
    product(generate_system_s(spincomponents=['+-']),
            generate_qrel_q(), generate_gc_g()))
def test_xi(gpw_files, system, qrel, gammacentered):
    """Test that the calculated self-enhancement function does not change
    when varying internal calculator parameters.

    The self-enhancement function is calculated for every combination of
    qsymmetry and band summation scheme. Each result is compared to the
    average over all combinations.
    """
    # ---------- Inputs ---------- #
    wfs, spincomponent = system
    nbands = response_band_cutoff[wfs]
    atol, rtol = get_tolerances(system, qrel)
    q_c = get_q_c(wfs, qrel)

    complex_frequencies = np.array([0., 0.05, 0.1, 0.2]) + 0.1j
    zd = ComplexFrequencyDescriptor.from_array(complex_frequencies)

    ecut = 50
    rshelmax = 0

    # Distribute over two blocks when possible. As in generate_nblocks_n(),
    # only use a block count that divides the world size evenly (previously,
    # odd world sizes > 1 would request 2 blocks).
    nblocks = 2 if world.size % 2 == 0 else 1

    fixed_kwargs = dict(nbands=nbands,
                        ecut=ecut,
                        gammacentered=gammacentered,
                        rshelmax=rshelmax,
                        nblocks=nblocks)

    # Parameters to cross-tabulate
    qsymmetry_s = [True, False]
    bandsummation_b = ['double', 'pairwise']

    # ---------- Script ---------- #

    calc = GPAW(gpw_files[wfs], parallel=dict(domain=1))
    gs = ResponseGroundStateAdapter(calc)

    xi_mzGG = []
    for qsymmetry in qsymmetry_s:
        for bandsummation in bandsummation_b:
            xi_calc = SelfEnhancementCalculator(
                gs,
                qsymmetry=qsymmetry,
                bandsummation=bandsummation,
                **fixed_kwargs)
            xi = xi_calc.calculate(spincomponent, q_c, zd)
            xi_mzGG.append(xi.array)
    xi_mzGG = np.array(xi_mzGG)

    # Test versus average
    avgxi_zGG = np.average(xi_mzGG, axis=0)
    for xi_zGG in xi_mzGG:
        assert xi_zGG == pytest.approx(avgxi_zGG, rel=rtol, abs=atol)

367 

368 

369# ---------- Test functionality ---------- # 

370 

371 

class ChiKSTestingFactory:
    """Factory to calculate and cache chiks objects.

    All calculations share the same ground state, spin component, q-vector,
    frequencies, number of bands, plane-wave cutoff and pw-grid centering.
    The ground state is wrapped in a PAW correction cache. Only the internal
    calculation parameters (see __call__) are varied. Calculated chiks are
    cached by these parameters.
    """

    def __init__(self, calc,
                 spincomponent, q_c, zd,
                 nbands, ecut, gammacentered):
        self.gs = GSAdapterWithPAWCache(calc)
        self.context = ResponseContext()
        self.spincomponent = spincomponent
        self.q_c = q_c
        self.zd = zd
        self.nbands = nbands
        self.ecut = ecut
        self.gammacentered = gammacentered

        # Calculated chiks, keyed by a string of the calculation parameters
        self.cached_chiks = {}

    def __call__(self,
                 qsign: int = 1,
                 distribution: str = 'GZg',
                 disable_syms: bool = False,
                 bandsummation: str = 'pairwise',
                 nblocks: int = 1):
        """Return chiks for the given internal calculation parameters.

        qsign = -1 evaluates chiks at -q_c instead of q_c. Results are cached,
        so repeated calls with the same parameters reuse the same object.
        """
        key = (f'{qsign},{distribution},{disable_syms}'
               f',{bandsummation},{nblocks}')
        if key in self.cached_chiks:
            return self.cached_chiks[key]

        calculator = ChiKSCalculator(
            self.gs, context=self.context,
            ecut=self.ecut, nbands=self.nbands,
            gammacentered=self.gammacentered,
            qsymmetry=not disable_syms,
            bandsummation=bandsummation,
            nblocks=nblocks)

        # Calculate chiks manually via the private methods, so that the
        # internal distribution can be specified
        internals = calculator._set_up_internals(
            self.spincomponent, qsign * self.q_c, self.zd,
            distribution=distribution)
        chiks = calculator._calculate(*internals)
        chiks = chiks.copy_with_global_frequency_distribution()

        self.cached_chiks[key] = chiks
        return chiks

    def check_parameter_self_consistency(self,
                                         parameter: str, values: list,
                                         atol: float,
                                         rtol: float,
                                         cross_tabulation: dict):
        """Assert that toggling `parameter` between two values leaves chiks
        unchanged, for every cross-tabulated combination of other parameters.
        """
        assert len(values) == 2
        first, second = values[0], values[1]
        for base_kwargs in self.generate_cross_tabulated_kwargs(
                cross_tabulation):
            kwargs1 = {**base_kwargs, parameter: first}
            kwargs2 = {**base_kwargs, parameter: second}
            chiks1 = self(**kwargs1)
            chiks2 = self(**kwargs2)
            compare_pw_bases(chiks1, chiks2)
            assert chiks2.array == pytest.approx(
                chiks1.array, rel=rtol, abs=atol), f'{kwargs2}'

    def check_reciprocity_and_inversion_symmetry(self,
                                                 atol: float,
                                                 rtol: float,
                                                 cross_tabulation: dict):
        """Check the symmetries relating chiks at q and -q, for every
        cross-tabulated combination of calculation parameters."""
        q_is_gamma = np.allclose(self.q_c, 0.)
        for kwargs in self.generate_cross_tabulated_kwargs(cross_tabulation):
            chiks_q = self(**kwargs)
            # At q = 0, -q equals q, so chiks can be reused
            chiks_mq = chiks_q if q_is_gamma else self(qsign=-1, **kwargs)
            check_reciprocity_and_inversion_symmetry(chiks_q, chiks_mq,
                                                     atol=atol, rtol=rtol)

    @staticmethod
    def generate_cross_tabulated_kwargs(cross_tabulation: dict):
        """Yield a kwargs dict for every combination of parameter values.

        cross_tabulation maps parameter names to lists of values. The
        Cartesian product is generated with the last parameter varying
        fastest.
        """
        keys = list(cross_tabulation)
        for combination in product(*(cross_tabulation[key] for key in keys)):
            yield dict(zip(keys, combination))

457 

458 

class GSAdapterWithPAWCache(ResponseGroundStateAdapter):
    """Add a PAW correction cache to the ground state adapter.

    WARNING: Use with care! The cache is only valid, when the plane-wave
    representations are identical and the functional f[n](r) is not changed.
    Cache entries are matched on (q_c, ecut, gammacentered) of the plane-wave
    descriptor only; rshe_a is not part of the cache key.
    """

    def __init__(self, calc):
        super().__init__(calc)

        # Parallel lists: the i-th correction was calculated for the i-th
        # (q_c, ecut, gammacentered) parameter tuple
        self._cached_corrections = []
        self._cached_parameters = []

    def matrix_element_paw_corrections(self, qpd, rshe_a):
        """Overwrite method with a cached version."""
        index = self._cache_lookup(qpd)
        if index is None:
            return self._calculate_correction(qpd, rshe_a)
        return self._cached_corrections[index]

    def _calculate_correction(self, qpd, rshe_a):
        """Calculate the correction with the parent method and cache it."""
        correction = super().matrix_element_paw_corrections(qpd, rshe_a)
        self._cached_corrections.append(correction)
        self._cached_parameters.append(
            (qpd.q_c, qpd.ecut, qpd.gammacentered))
        return correction

    def _cache_lookup(self, qpd):
        """Return the index of the first cache hit for qpd, or None."""
        return next(
            (index
             for index, (q_c, ecut, gammacentered)
             in enumerate(self._cached_parameters)
             if np.allclose(qpd.q_c, q_c)
             and abs(qpd.ecut - ecut) < 1e-8
             and qpd.gammacentered == gammacentered),
            None)

495 

496 

def compare_pw_bases(chiks1, chiks2):
    """Assert that two calculated chiks share the same plane-wave basis.

    The plane-wave coordinates (from get_pw_coordinates) must have the same
    shape and their element-wise difference must vanish within the default
    np.allclose absolute tolerance.
    """
    coords1_Gc, coords2_Gc = (get_pw_coordinates(chiks.qpd)
                              for chiks in (chiks1, chiks2))
    assert coords1_Gc.shape == coords2_Gc.shape
    assert np.allclose(coords1_Gc - coords2_Gc, 0.)

503 

504 

def check_reciprocity_and_inversion_symmetry(chiks1, chiks2, *, atol, rtol):
    """Check the susceptibilities for reciprocity and inversion symmetry.

    Here chiks1 is calculated at q and chiks2 at -q (which may be the very
    same object when q = 0). For real-life periodic systems with an
    inversion center, we test the reciprocity relation (valid both for μν=00
    and μν=+-),

    χ_(KS,GG')^(μν)(q, ω) = χ_(KS,-G'-G)^(μν)(-q, ω),

    the inversion symmetry relation,

    χ_(KS,GG')^(μν)(q, ω) = χ_(KS,-G-G')^(μν)(-q, ω),

    and the combination of the two,

    χ_(KS,GG')^(μν)(q, ω) = χ_(KS,G'G)^(μν)(q, ω).

    Unfortunately, there will always be random noise in the wave functions,
    such that these symmetries cannot be fulfilled exactly. Generally
    speaking, the "symmetry" noise can be reduced by running with
    symmetry='off' in the ground state calculation.
    """
    # Index map pairing plane-wave components G of chiks1 with -G of chiks2
    invmap_GG = get_inverted_pw_mapping(chiks1.qpd, chiks2.qpd)

    # Compare χ(q) and χ(-q), frequency by frequency
    for chi1_GG, chi2_GG in zip(chiks1.array, chiks2.array):
        chi2inv_GG = chi2_GG[invmap_GG]
        # Reciprocity
        assert chi2inv_GG.T == pytest.approx(chi1_GG, rel=rtol, abs=atol)
        # Inversion symmetry
        assert chi2inv_GG == pytest.approx(chi1_GG, rel=rtol, abs=atol)

    # Combination of the two: at both q-vectors and every frequency, the
    # susceptibility matrix is symmetric
    for chiks in (chiks1, chiks2):
        for chiks_GG in chiks.array:  # array = chiks_zGG
            assert chiks_GG.T == pytest.approx(chiks_GG, rel=rtol, abs=atol)