Coverage for gpaw/berryphase.py: 89%

289 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2import warnings 

3from pathlib import Path 

4 

5import numpy as np 

6from ase.dft.bandgap import bandgap 

7from ase.dft.kpoints import get_monkhorst_pack_size_and_offset 

8 

9from gpaw import GPAW 

10from gpaw.ibz2bz import get_overlap 

11from gpaw.ibz2bz import (get_overlap_coefficients, 

12 get_phase_shifted_overlap_coefficients) 

13from gpaw.mpi import rank, serial_comm, world 

14from gpaw.spinorbit import soc_eigenstates 

15from gpaw.utilities.blas import gemmdot 

16 

17from ase import Atoms 

18from ase.parallel import parprint 

19 

20 

class ZeroBandgap(Exception):
    """Raised when a Berry-phase calculation is requested for a system
    whose band gap is zero."""

23 

24 

def get_berry_phases(calc, spin=0, dir=0, check2d=False):
    """Berry phases of the occupied bands along strings of k-points.

    The Brillouin zone is cut into strings of k-points parallel to
    reciprocal-lattice direction ``dir``.  For every string the Berry phase
    Im ln prod_j det <u_n,k_j | u_m,k_j+1> is computed from the overlaps
    of the cell-periodic parts of all occupied states.

    Parameters
    ----------
    calc : GPAW calculator, str or Path
        Calculator, or a .gpw file that is read with ``serial_comm``.
        The calculation must be done without symmetry (asserted).
    spin : int
        Spin channel to use (selects ``nocc_s[spin]`` and ``kpt_qs[ik][spin]``).
    dir : int
        Direction (0, 1 or 2) along which the k-point strings run.
    check2d : bool
        If True, also compute a phase from <u_k | e^{-iG.r} u_k> with
        G_c = [0, 0, 1] at the first k-point of every string, and print a
        warning when its string average differs from the Berry-phase average
        by more than 0.01.

    Returns
    -------
    indices_kk : ndarray of shape (nkperp, size[dir])
        BZ k-point indices; each row is one string along ``dir``.
    phases : list of float
        Berry phase of each string, shifted by multiples of 2*pi so that
        consecutive strings are continuous.

    Raises
    ------
    ZeroBandgap
        If ``bandgap(calc)`` reports a zero gap.
    """
    if isinstance(calc, (str, Path)):
        calc = GPAW(calc, communicator=serial_comm, txt=None)

    assert len(calc.symmetry.op_scc) == 1  # does not work with symmetry
    gap = bandgap(calc)[0]

    if gap == 0.0:
        raise ZeroBandgap(
            'Berry-phase calculation requires non-zero band gap.')

    # Number of occupied bands per spin from the (integer) magnetic moment.
    M_raw = calc.get_magnetic_moment()
    M = np.round(M_raw)
    assert np.allclose(M, M_raw, atol=0.05), f'Non-integer magmom {M_raw}'
    nvalence = calc.wfs.setups.nvalence
    nocc_s = [int((nvalence + M) / 2), int((nvalence - M) / 2)]
    nocc = nocc_s[spin]
    if not calc.wfs.collinear:
        # Spinor states: every valence electron occupies its own band.
        nocc = nvalence
    else:
        assert np.allclose(np.sum(nocc_s), nvalence)

    bands = list(range(nocc))
    kpts_kc = calc.get_bz_k_points()
    size = get_monkhorst_pack_size_and_offset(kpts_kc)[0]
    Nk = len(kpts_kc)
    wfs = calc.wfs

    dO_aii = get_overlap_coefficients(wfs)

    kd = calc.wfs.kd

    # Real-space cell-periodic wave functions and projections of the
    # occupied bands for every BZ k-point.
    u_knR = []
    proj_k = []
    for k in range(Nk):
        ik = kd.bz2ibz_k[k]
        k_c = kd.bzk_kc[k]
        ik_c = kd.ibzk_kc[ik]
        # Since symmetry is off this should always hold
        assert np.allclose(k_c, ik_c)
        kpt = wfs.kpt_qs[ik][spin]

        # Check that all states are occupied
        assert np.all(kpt.f_n[:nocc] > 1e-6)
        N_c = wfs.gd.N_c

        ut_nR = []
        psit_nG = kpt.psit_nG
        for n in range(nocc):
            if wfs.collinear:
                ut_nR.append(wfs.pd.ifft(psit_nG[n], ik))
            else:
                ut0_R = wfs.pd.ifft(psit_nG[n][0], ik)
                ut1_R = wfs.pd.ifft(psit_nG[n][1], ik)
                # Here R includes a spinor index
                ut_nR.append([ut0_R, ut1_R])

        u_knR.append(ut_nR)
        proj_k.append(kpt.projections)

    # Reorder the k-point index grid so that ``dir`` is the last axis; each
    # row of indices_kk is then one string of k-points along ``dir``.
    indices_kkk = np.arange(Nk).reshape(size)
    tmp = np.concatenate([[i for i in range(3) if i != dir], [dir]])
    indices_kk = indices_kkk.transpose(tmp).reshape(-1, size[dir])

    nkperp = len(indices_kk)
    phases = []
    if check2d:
        phases2d = []
    # plane average of overlaps M_nm = <u_nk | u_mk+q>
    # on k-plane with normal vector in dir
    for indices_k in indices_kk:
        M_knn = []
        for j in range(size[dir]):
            k1 = indices_k[j]
            G_c = np.array([0, 0, 0])
            if j + 1 < size[dir]:
                k2 = indices_k[j + 1]
            else:
                # Close the string: the neighbour of the last k-point is the
                # first one translated by a reciprocal lattice vector.
                k2 = indices_k[0]
                # pbc:
                # psi_k(r) = psi_k+G(r) -> u_k(r) = e^-iGr u_k+G(r)
                G_c[dir] = 1
            u1_nR = np.array(u_knR[k1])
            u2_nR = np.array(u_knR[k2])
            k1_c = kpts_kc[k1]
            k2_c = kpts_kc[k2] + G_c

            if np.any(G_c):
                # Multiply by e^{-iG.r} evaluated on the grid points
                # (G.r = 2*pi * sum_c G_c * i_c / N_c).
                emiGr_R = np.exp(-2j * np.pi *
                                 np.dot(np.indices(N_c).T, G_c / N_c).T)
                u2_nR = u2_nR * emiGr_R

            bG_c = k2_c - k1_c

            phase_shifted_dO_aii = get_phase_shifted_overlap_coefficients(
                dO_aii, calc.spos_ac, -bG_c)

            # < u_nk | u_mk+1 >
            M_nn = get_overlap(bands, wfs.gd, u1_nR, u2_nR,
                               proj_k[k1], proj_k[k2], phase_shifted_dO_aii)
            M_knn.append(M_nn)

        # det_k = det(k, nbands, nbands)
        det_k = np.linalg.det(M_knn)
        # Berry phase of this string: Im ln of the product of determinants.
        phases.append(np.imag(np.log(np.prod(det_k))))
        if check2d:
            # In the case of 2D systems we can check the
            # result
            # NOTE(review): bG_c below is the value left over from the last
            # iteration of the loop above; it equals G_c = [0, 0, 1] only
            # when dir == 2 and there is a single k-point along z --
            # presumably the intended use of check2d; confirm.
            k1 = indices_k[0]
            k1_c = kpts_kc[k1]
            G_c = [0, 0, 1]
            u1_nR = u_knR[k1]
            emiGr_R = np.exp(-2j * np.pi *
                             np.dot(np.indices(N_c).T, G_c / N_c).T)
            u2_nR = u1_nR * emiGr_R

            phase_shifted_dO_aii = get_phase_shifted_overlap_coefficients(
                dO_aii, calc.spos_ac, -bG_c)
            M_nn = get_overlap(bands, calc.wfs.gd, u1_nR, u2_nR,
                               proj_k[k1], proj_k[k1], phase_shifted_dO_aii)

            phase2d = np.imag(np.log(np.linalg.det(M_nn)))
            phases2d.append(phase2d)

    # Make sure the phases are continuous
    # (shift each phase by the multiple of 2*pi closest to its predecessor).
    for p in range(nkperp - 1):
        delta = phases[p] - phases[p + 1]
        phases[p + 1] += np.round(delta / (2 * np.pi)) * 2 * np.pi

    # plane average over all perpendicular kpoints to direction dir
    # (only used for the check2d comparison below).
    phase = np.sum(phases) / nkperp
    if check2d:
        for p in range(nkperp - 1):
            delta = phases2d[p] - phases2d[p + 1]
            phases2d[p + 1] += np.round(delta / (2 * np.pi)) * 2 * np.pi

        phase2d = np.sum(phases2d) / nkperp

        diff = abs(phase - phase2d)
        if diff > 0.01:
            msg = 'Warning wrong phase: phase={}, 2dphase={}'
            print(msg.format(phase, phase2d))

    return indices_kk, phases

170 

171 

def polarization_phase(gpw_wfs: Path, comm, cleanup: bool = False):
    """Return the polarization phases, computed serially on rank 0.

    The polarization phase combines the electronic Berry phase with the
    ionic contribution.  Following R. Resta and D. Vanderbilt (Physics of
    Ferroelectrics), the electrical polarization is

        P_v = e/(2 * pi)^3 sum_n phi_nv + e/vol sum_a Z_a * r_av

    where the Berry phase in Cartesian coordinates,

        phi_nv = Im(int_BZ dk^2 dk_v <u_nk | d/dk_v | u_nk>)

               = Im( ln prod_{j=0}^{M-1} <u_n,k_j | u_n,k_j+1> ),

    is evaluated for each dimension as a product of Bloch-function
    overlaps.  The quantity returned here is the polarization phase

        phi_v = 2 * pi * vol * P_v

    Parameters
    ----------
    gpw_wfs : Path
        .gpw file containing the wave functions (read on rank 0 only).
    comm : communicator
        Object providing ``rank`` and ``broadcast(array, root)``.
    cleanup : bool
        Delete ``gpw_wfs`` after it has been read.

    Returns
    -------
    dict
        Keys 'phase_c', 'electronic_phase_c', 'atomic_phase_c' and
        'dipole_phase_c', each a length-3 array broadcast from rank 0.
    """
    if comm.rank != 0:
        # Receive buffers; filled by the broadcast from rank 0 below.
        phases_c = {key: np.empty(3)
                    for key in ('phase_c', 'electronic_phase_c',
                                'atomic_phase_c', 'dipole_phase_c')}
    else:
        # The actual calculation runs in serial on the master rank.
        phases_c = _get_phases(gpw_wfs, cleanup=cleanup)

    for array_c in phases_c.values():
        comm.broadcast(array_c, 0)

    return phases_c

214 

215 

def _get_phases(gpw_wfs: Path, cleanup: bool = False):
    """Compute all polarization-phase contributions from a .gpw file.

    The file is read with ``serial_comm``.

    Parameters
    ----------
    gpw_wfs : Path or str
        .gpw file containing the wave functions.
    cleanup : bool
        Delete ``gpw_wfs`` once the phases have been computed.

    Returns
    -------
    dict
        'electronic_phase_c' (Berry phase), 'atomic_phase_c' (valence
        charges at the scaled atomic positions), 'dipole_phase_c' (from the
        dipole moment) and their combination 'phase_c'; each is a length-3
        array.
    """
    parprint(f'Reading wfs from {gpw_wfs}')
    calc = GPAW(gpw_wfs, communicator=serial_comm, txt=None)
    atoms = calc.get_atoms()

    parprint('Calculating polarization')
    electronic_phase_c = get_electronic_polarization_phase(calc)
    # valence electron number for each atom
    Nv_a = [setup.Nv for setup in calc.setups]
    atomic_phase_c = get_atomic_polarization_phase(Nv_a, calc.spos_ac)
    dipole_v = calc.get_dipole_moment()
    cell_cv = atoms.get_cell()
    dipole_phase_c = get_dipole_polarization_phase(dipole_v, cell_cv)

    # total phase: electronic + atomic along periodic directions; along
    # non-periodic directions the dipole-moment phase is used instead.
    pbc_c = atoms.get_pbc()
    phase_c = electronic_phase_c + atomic_phase_c
    phase_c[~pbc_c] = dipole_phase_c[~pbc_c]

    # remove file gpw_wfs; wrap in Path so that plain str paths (which
    # GPAW() accepts above) do not fail with AttributeError here.
    if cleanup:
        Path(gpw_wfs).unlink()

    phases_c = {
        'phase_c': phase_c,
        'electronic_phase_c': electronic_phase_c,
        'atomic_phase_c': atomic_phase_c,
        'dipole_phase_c': dipole_phase_c,
    }

    return phases_c

247 

248 

def ionic_phase(atoms: Atoms):
    """Return the ionic polarization phase of *atoms* using bare nuclei.

    Intended as a routine to check the Born-charge implementation: the
    atomic numbers are used as charges, so the system is not charge neutral
    and the acoustic sum rule does not hold.

    Returns a dict in which 'phase_c' and 'atomic_phase_c' refer to the same
    length-3 array.
    """
    charges_a = atoms.numbers
    scaled_positions_ac = atoms.get_scaled_positions()
    phase_c = get_atomic_polarization_phase(charges_a, scaled_positions_ac)
    return {'phase_c': phase_c, 'atomic_phase_c': phase_c}

263 

264 

def get_electronic_polarization_phase(calc):
    """Electronic (Berry-phase) polarization phase along the three directions.

    For every direction c and spin channel, the Berry phases of all k-point
    strings from ``get_berry_phases`` are averaged.  The result is weighted
    by the band occupation: 2 electrons per band for collinear spin-paired
    calculations (divided by the number of spins), 1 for non-collinear
    spinor calculations.

    Parameters
    ----------
    calc : GPAW calculator
        Must run on a single process (``calc.world.size == 1``).

    Returns
    -------
    ndarray of shape (3,)
    """
    # get_berry_phases is defined in this module; no (circular) self-import
    # of gpaw.berryphase is needed.
    assert calc.world.size == 1

    phase_c = np.zeros((3,), float)
    # calculate and average the berry phases of all k-point strings
    nspins = calc.get_number_of_spins()
    for c in range(3):
        for spin in range(nspins):
            _, phases = get_berry_phases(calc, dir=c, spin=spin)
            phase_c[c] += np.mean(phases)

    # Collinear bands hold two electrons (per spin channel sum, divided by
    # nspins); non-collinear spinor bands hold one electron each.
    occupation = 2 if calc.wfs.collinear else 1
    phase_c = phase_c * occupation / nspins

    return phase_c

284 

285 

def get_atomic_polarization_phase(Nv_a, spos_ac):
    """Return 2*pi * sum_a Nv_a[a] * spos_ac[a], the point-charge phase.

    Nv_a holds one charge per atom and spos_ac the scaled positions.
    """
    weighted_c = np.dot(Nv_a, spos_ac)
    return 2 * np.pi * weighted_c

288 

289 

def get_dipole_polarization_phase(dipole_v, cell_cv):
    """Project a Cartesian dipole moment onto the reciprocal lattice vectors.

    The reciprocal vectors (rows of B_cv) include the factor 2*pi, so the
    result is the phase B_cv . dipole_v for each lattice direction.
    """
    reciprocal_cv = np.linalg.inv(cell_cv).T * 2 * np.pi
    return reciprocal_cv.dot(dipole_v)

294 

295 

def parallel_transport(calc, direction=0, name=None, scale=1.0, bands=None,
                       theta=0.0, phi=0.0, comm=None):
    """
    Parallel transport.
    The parallel transport algorithm corresponds to the construction
    of hybrid Wannier functions localized along the Nloc direction.
    While these are not constructed explicitly one may obtain the
    Wannier Charge centers which are given by the eigenvalues of
    the Berry phase matrix (except for a factor of 2*pi) phi_km.
    In addition, one may evaluate the expectation value of spin
    on each of these states along the easy axis (z-axis for
    nonmagnetic systems), which is given by S_km.

    Parameters
    ----------
    calc : GPAW calculator, str or Path
        Calculator, or a .gpw file read with ``serial_comm``.
    direction : int
        Direction (0, 1 or 2) along which the k-point strings run.  At least
        one of the two other directions must contain a single k-point.
    name : str or None
        If given, phi_km and S_km are saved to ``phases_{name}.npz`` by
        rank 0 of *comm*.
    scale, theta, phi : float
        Passed on to ``soc_eigenstates``.
    bands : sequence of int or None
        Bands to include; defaults to ``range(int(number of electrons))``.
    comm : communicator or None
        Communicator over which the k-point strings are distributed
        (defaults to ``world``).

    Output:
        phi_km, S_km (see above), both of shape (Npar, len(bands)) where
        Npar is the number of k-point strings.
    """
    comm = comm or world

    if isinstance(calc, (str, Path)):
        calc = GPAW(calc, txt=None, communicator=serial_comm)

    if bands is None:
        nv = int(calc.get_number_of_electrons())
        bands = range(nv)

    cell_cv = calc.wfs.gd.cell_cv
    icell_cv = (2 * np.pi) * np.linalg.inv(cell_cv).T
    r_g = calc.wfs.gd.get_grid_point_coordinates()

    dO_aii = get_overlap_coefficients(calc.wfs)

    N_c = calc.wfs.kd.N_c
    assert 1 in np.delete(N_c, direction)

    Nk = int(np.prod(N_c))
    Nloc = N_c[direction]
    Npar = Nk // Nloc

    # Parallelization: contiguous chunks of strings per rank of *comm*
    # (comm.rank, not the world rank, so that sub-communicators work).
    myKsize = -(-Npar // comm.size)
    myKrange = range(comm.rank * myKsize,
                     min((comm.rank + 1) * myKsize, Npar))

    # Array of k-point indices of each path; q index is the loc direction.
    # BZ k-points are assumed ordered with the last axis fastest (as in
    # get_berry_phases); moving `direction` to the last axis gives one
    # row per string.  This also covers direction == 0 with Nky == 1 and
    # Nkz > 1, which the previous per-direction formulas got wrong.
    axes_c = [c for c in range(3) if c != direction] + [direction]
    kpts_kq = np.arange(Nk).reshape(N_c).transpose(axes_c).reshape(
        Npar, Nloc).tolist()

    G_c = np.array([0, 0, 0])
    G_c[direction] = 1
    G_v = np.dot(G_c, icell_cv)

    kpts_kc = calc.get_bz_k_points()

    # Step between neighbouring k-points of a string (G itself if Nloc == 1).
    if Nloc > 1:
        b_c = kpts_kc[kpts_kq[0][1]] - kpts_kc[kpts_kq[0][0]]
    else:
        b_c = G_c
    phase_shifted_dO_aii = get_phase_shifted_overlap_coefficients(
        dO_aii, calc.spos_ac, -b_c)

    soc_kpts = soc_eigenstates(calc, scale=scale, theta=theta, phi=phi)

    def projections(bz_index):
        # Copy, so that rotating the result in place leaves soc_kpts intact.
        proj = soc_kpts[bz_index].projections
        new_proj = proj.new()
        new_proj.matrix.array = proj.matrix.array.copy()
        return new_proj

    def wavefunctions(bz_index):
        return soc_kpts[bz_index].wavefunctions(calc, periodic=True)

    def rotate(M_mm, u_nsG, proj):
        # Rotate the selected bands of u_nsG and proj in place with the
        # unitary part of the overlap matrix M_mm (from its SVD), which
        # makes the gauge smooth with respect to the previous k-point.
        # Returns the rotation matrix U_mm.
        V_mm, sing_m, W_mm = np.linalg.svd(M_mm)
        U_mm = np.dot(V_mm, W_mm).conj()
        u_mysxz = np.dot(U_mm, np.swapaxes(u_nsG[bands], 0, 3))
        u_mxsyz = np.swapaxes(u_mysxz, 1, 3)
        u_msxyz = np.swapaxes(u_mxsyz, 1, 2)
        u_nsG[bands] = u_msxyz
        for a in range(len(calc.atoms)):
            assert not proj.collinear
            P_msi = proj[a][bands]
            for s in range(2):
                P_msi[:, s] = np.dot(U_mm, P_msi[:, s])
            proj[a][bands] = P_msi
        return U_mm

    phi_km = np.zeros((Npar, len(bands)), float)
    S_km = np.zeros((Npar, len(bands)), float)
    # Loop over the direction parallel components
    for k in myKrange:
        U_qmm = [np.eye(len(bands))]
        qpts_q = kpts_kq[k]
        # Loop over kpoints in the phase direction
        for q in range(Nloc - 1):
            iq1 = qpts_q[q]
            iq2 = qpts_q[q + 1]
            if q == 0:
                u1_nsG = wavefunctions(iq1)
                proj1 = projections(iq1)

            u2_nsG = wavefunctions(iq2)
            proj2 = projections(iq2)

            M_mm = get_overlap(bands, calc.wfs.gd, u1_nsG, u2_nsG,
                               proj1, proj2, phase_shifted_dO_aii)
            U_qmm.append(rotate(M_mm, u2_nsG, proj2))
            u1_nsG = u2_nsG
            proj1 = proj2
        U_qmm = np.array(U_qmm)

        # Fix phases for last point: the first k-point translated by G,
        # u_{k+G}(r) = e^{-iG.r} u_k(r).
        iq0 = qpts_q[0]
        if Nloc == 1:
            u1_nsG = wavefunctions(iq0)
            proj1 = projections(iq0)
        u2_nsG = wavefunctions(iq0)
        u2_nsG[:] *= np.exp(-1.0j * gemmdot(G_v, r_g, beta=0.0))
        proj2 = projections(iq0)

        M_mm = get_overlap(bands, calc.wfs.gd, u1_nsG, u2_nsG,
                           proj1, proj2, phase_shifted_dO_aii)
        rotate(M_mm, u2_nsG, proj2)

        # Get overlap between first kpts and its smoothly translated image
        u2_nsG[:] *= np.exp(1.0j * gemmdot(G_v, r_g, beta=0.0))
        u1_nsG = wavefunctions(iq0)
        proj1 = projections(iq0)
        M_mm = get_overlap(bands, calc.wfs.gd, u1_nsG,
                           u2_nsG, proj1, proj2, dO_aii)

        # Eigenphases of the closing overlap are the Berry phases.
        l_m, l_mm = np.linalg.eig(M_mm)
        phi_km[k] = np.angle(l_m)

        # String-averaged spin-z matrix in the rotated basis (even rows of
        # v_nm are taken as spin up, odd rows as spin down), expressed in
        # the eigenbasis l_mm.
        A_mm = np.zeros_like(l_mm, complex)
        for q in range(Nloc):
            iq = qpts_q[q]
            U_mm = U_qmm[q]
            v_mn = soc_kpts[iq].v_mn
            v_nm = np.einsum('xm, mn -> nx', U_mm, v_mn[bands])
            A_mm += np.dot(v_nm[::2].T.conj(), v_nm[::2])
            A_mm -= np.dot(v_nm[1::2].T.conj(), v_nm[1::2])
        A_mm /= Nloc
        S_km[k] = np.diag(l_mm.T.conj().dot(A_mm).dot(l_mm)).real

    # Each rank filled only its own strings; the others are still zero.
    comm.sum(phi_km)
    comm.sum(S_km)

    if not calc.density.collinear:
        warnings.warn('WARNING: Spin projections are not meaningful '
                      + 'for non-collinear calculations')

    if name is not None:
        if comm.rank == 0:
            np.savez(f'phases_{name}.npz', phi_km=phi_km, S_km=S_km)
        comm.barrier()

    return phi_km, S_km