Coverage for gpaw/kpt_descriptor.py: 90%

326 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1# Copyright (C) 2003 CAMP 

2# Please see the accompanying LICENSE file for further information. 

3 

4"""K-point descriptor.""" 

5 

6from __future__ import annotations 

7from typing import Optional, Sequence 

8 

9import numpy as np 

10from ase.calculators.calculator import kptdensity2monkhorstpack 

11from ase.dft.kpoints import get_monkhorst_pack_size_and_offset, monkhorst_pack 

12 

13import gpaw.cgpaw as cgpaw 

14import gpaw.mpi as mpi 

15from gpaw import KPointError 

16from gpaw.typing import Array1D 

17from gpaw.kpoint import KPoint 

18 

19 

def to1bz(bzk_kc, cell_cv):
    """Wrap k-points to the first Brillouin zone.

    Each k-point is shifted by the reciprocal lattice vector (with integer
    coordinates in {-1, 0, 1}) that lies closest to it in Cartesian space.

    bzk_kc: (n,3) ndarray
        Array of k-points in units of the reciprocal lattice vectors.
    cell_cv: (3,3) ndarray
        Unit cell.

    Returns a wrapped copy of *bzk_kc* (for ndarray input the original
    array is left unchanged).
    """
    # Reciprocal lattice vectors as rows, including the 2 pi factor.
    recip_cv = 2.0 * np.pi * np.linalg.inv(cell_cv).T

    # The 27 candidate translations n_c in {-1, 0, 1}^3 (last coordinate
    # varying fastest) and their Cartesian counterparts.
    shift_xc = np.indices((3, 3, 3)).reshape(3, -1).T - 1
    shift_xv = shift_xc.dot(recip_cv)

    wrapped_kc = bzk_kc.copy()
    for k, kpt_v in enumerate(np.dot(bzk_kc, recip_cv)):
        dist2_x = ((shift_xv - kpt_v) ** 2).sum(1)
        # Equidistant candidates must not be picked on the basis of
        # numerical noise: round the excess over the minimum to 6 decimals
        # and let the candidate with the lowest index win.
        best = (dist2_x - dist2_x.min()).round(6).argmin()
        wrapped_kc[k] -= shift_xc[best]
    return wrapped_kc

50 

51 

def kpts2sizeandoffsets(size=None, density=None, gamma=None, even=None,
                        atoms=None):
    """Helper function for selecting k-points.

    Use either size or density.  If both are given, size takes precedence
    and density is ignored.

    size: 3 ints
        Number of k-points.
    density: float
        K-point density in units of k-points per Ang^-1.
    gamma: None or bool
        Should the Gamma-point be included?  Yes / no / don't care:
        True / False / None.
    even: None or bool
        Should the number of k-points be even?  Yes / no / don't care:
        True / False / None.  Only used together with density.
    atoms: Atoms object
        Needed for calculating k-point density.  Its periodic boundary
        conditions decide along which directions a Gamma-point shift is
        applied; if omitted, all three directions are treated as periodic.

    Returns ``(size, offsets)`` where offsets is a list of three numbers.

    Raises ValueError if density is given without atoms.
    """

    if size is None:
        if density is None:
            size = [1, 1, 1]
        else:
            if atoms is None:
                raise ValueError('Cannot set k-points from "density" '
                                 'unless atoms are provided.')
            size = kptdensity2monkhorstpack(atoms, density, even)

    offsets = [0, 0, 0]

    if gamma is not None:
        # Without an Atoms object there is no information about the
        # boundary conditions, so assume periodicity in all directions.
        pbc = [True, True, True] if atoms is None else atoms.pbc
        for i, s in enumerate(size):
            # An odd Monkhorst-Pack grid contains Gamma and an even one does
            # not; shift by half a grid spacing when that disagrees with the
            # request.
            if pbc[i] and s % 2 != bool(gamma):
                offsets[i] = 0.5 / s

    return size, offsets

87 

88 

class KPointDescriptor:
    """Descriptor-class for k-points."""

    def __init__(self, kpts, nspins: int = 1):
        """Construct descriptor object for kpoint/spin combinations (ks-pair).

        Parameters:

        kpts: None, sequence of 3 ints, or (n,3)-shaped array
            Specification of the k-point grid.  None=Gamma, list of
            ints=Monkhorst-Pack, ndarray=user specified.
        nspins: int
            Number of spins.

        Attributes
        =================== =================================================
        ``N_c``             Number of k-points in the different directions
                            (None if the k-points are not recognized as a
                            Monkhorst-Pack grid).
        ``offset_c``        Offset of the Monkhorst-Pack grid (None if not
                            recognized as a Monkhorst-Pack grid).
        ``monkhorst``       True if ``N_c`` is known.
        ``bzk_kc``          K-points in the full BZ (scaled coordinates).
        ``nbzkpts``         Number of k-points in the full BZ.
        ``nspins``          Number of spins in total.
        ``nibzkpts``        Number of irreducible kpoints in 1st BZ.
        ``gamma``           Boolean indicator for gamma point calculation.
        ``comm``            MPI-communicator for kpoint distribution.
        ``rank0``           First rank that holds one extra IBZ k-point.
        ``mynk``            Number of IBZ k-points on this CPU.
        ``k0``              Global index of the first IBZ k-point on this
                            CPU.
        ``weight_k``        Weights of each IBZ k-point.
        ``weight_q``        Weights of the IBZ k-points on this CPU.
        ``ibzk_kc``         Irreducible k-points (scaled coordinates).
        ``ibzk_qc``         Irreducible k-points on this CPU.
        ``sym_k``           For each BZ k-point: index of the symmetry
                            operation that maps its IBZ k-point onto it.
        ``time_reversal_k`` For each BZ k-point: whether time reversal
                            (complex conjugation) is needed as well.
        ``bz2ibz_k``        For each BZ k-point: index of its IBZ k-point.
        ``ibz2bz_k``        For each IBZ k-point: presumably the index of
                            the corresponding BZ k-point.
        ``bz2bz_ks``        BZ index of k-point k transformed by symmetry s
                            (-1 if the result is not in the k-point set).
        ``symmetry``        Object representing symmetries (set by
                            ``set_symmetry()``).
        ``refine_info``     Optional extra information included in
                            ``str(self)``.
        =================== =================================================
        """

        self.N_c: Optional[Array1D] = None
        self.offset_c: Optional[Array1D] = None

        if kpts is None:
            self.bzk_kc = np.zeros((1, 3))
            self.N_c = np.array((1, 1, 1), dtype=int)
            self.offset_c = np.zeros(3)
        else:
            kpts = np.asarray(kpts)
            if kpts.ndim == 1:
                # Grid size given: generate a Monkhorst-Pack grid.
                self.N_c = np.array(kpts, dtype=int)
                self.bzk_kc = monkhorst_pack(self.N_c)
                self.offset_c = np.zeros(3)
            else:
                # Explicit k-points: try to recognize a (shifted)
                # Monkhorst-Pack grid.
                self.bzk_kc = np.array(kpts, dtype=float)
                try:
                    self.N_c, self.offset_c = \
                        get_monkhorst_pack_size_and_offset(self.bzk_kc)
                except ValueError:
                    # Presumably raised when the points do not form a
                    # Monkhorst-Pack grid; N_c and offset_c stay None.
                    pass
        self.nspins = nspins
        self.nbzkpts = len(self.bzk_kc)

        # Gamma-point calculation?
        self.gamma = self.nbzkpts == 1 and not self.bzk_kc.any()

        # Point group and time-reversal symmetry neglected:
        self.weight_k = np.ones(self.nbzkpts) / self.nbzkpts
        self.ibzk_kc = self.bzk_kc.copy()
        self.sym_k = np.zeros(self.nbzkpts, int)
        self.time_reversal_k = np.zeros(self.nbzkpts, bool)
        self.bz2ibz_k = np.arange(self.nbzkpts)
        self.ibz2bz_k = np.arange(self.nbzkpts)
        self.bz2bz_ks = np.arange(self.nbzkpts)[:, np.newaxis]
        self.nibzkpts = self.nbzkpts
        self.refine_info = None
        self.monkhorst = (self.N_c is not None)

        self.set_communicator(mpi.serial_comm)

    def __str__(self):
        """Return a text summary of the symmetry and the k-points.

        Uses ``self.symmetry``, so ``set_symmetry()`` (or ``copy()`` of a
        descriptor that has one) must have been called first.
        """
        s = str(self.symmetry)

        if self.refine_info is not None:
            s += '\n' + str(self.refine_info)

        if -1 in self.bz2bz_ks:
            s += 'Note: your k-points are not as symmetric as your crystal!\n'

        if self.gamma:
            s += '\n1 k-point (Gamma)'
        else:
            s += '\n%d k-points' % self.nbzkpts
            if self.monkhorst:
                s += ': %d x %d x %d Monkhorst-Pack grid' % tuple(self.N_c)
                if self.offset_c.any():
                    s += ' + ['
                    for x in self.offset_c:
                        # Print offsets of the form 1/n as fractions.
                        if x != 0 and abs(round(1 / x) - 1 / x) < 1e-12:
                            s += '1/%d,' % round(1 / x)
                        else:
                            s += '%f,' % x
                    s = s[:-1] + ']'

        # ' s'[1:n] is '' for n <= 1 and 's' otherwise (pluralization).
        s += ('\n%d k-point%s in the irreducible part of the Brillouin zone\n'
              % (self.nibzkpts, ' s'[1:self.nibzkpts]))

        if self.monkhorst:
            # For a Monkhorst-Pack grid the weights are integer multiples
            # of 1/nbzkpts; print them as fractions.
            w_k = self.weight_k * self.nbzkpts
            assert np.allclose(w_k, w_k.round())
            w_k = w_k.round()

        s += ' k-points in crystal coordinates weights\n'
        for k in range(self.nibzkpts):
            # Print the first 10 and the last k-point only.
            if k < 10 or k == self.nibzkpts - 1:
                if self.monkhorst:
                    s += ('%4d: %12.8f %12.8f %12.8f %6d/%d\n' %
                          ((k,) + tuple(self.ibzk_kc[k]) +
                           (w_k[k], self.nbzkpts)))
                else:
                    s += ('%4d: %12.8f %12.8f %12.8f %12.8f\n' %
                          ((k,) + tuple(self.ibzk_kc[k]) +
                           (self.weight_k[k],)))
            elif k == 10:
                s += ' ...\n'
        return s

    def set_symmetry(self, atoms, symmetry, comm=None):
        """Store the symmetry object and construct irreducible Brillouin zone.

        atoms: Atoms object
            Defines atom positions and types and also unit cell and
            boundary conditions.  Only the boundary conditions are used
            here: a ValueError is raised if any k-point has a non-zero
            component along a non-periodic direction.
        symmetry: Symmetry object
            Symmetry object.  The k-points are reduced with
            ``symmetry.reduce()`` only if time-reversal or point-group
            symmetry is enabled.
        comm: communicator or None
            Passed on to ``symmetry.reduce()``.

        Note: this does not redistribute the k-points; the
        communicator-dependent attributes are updated by
        ``set_communicator()``.
        """

        self.symmetry = symmetry

        # XXX we pass the whole atoms object just to complain if its PBCs
        # are not how we like them
        for c, periodic in enumerate(atoms.pbc):
            if not periodic and not np.allclose(self.bzk_kc[:, c], 0.0):
                raise ValueError('K-points can only be used with PBCs!')

        if symmetry.time_reversal or symmetry.point_group:
            (self.ibzk_kc, self.weight_k,
             self.sym_k,
             self.time_reversal_k,
             self.bz2ibz_k,
             self.ibz2bz_k,
             self.bz2bz_ks) = symmetry.reduce(self.bzk_kc, comm)

        # Number of irreducible k-points.
        self.nibzkpts = len(self.ibzk_kc)

    def set_communicator(self, comm):
        """Set k-point communicator and distribute the IBZ k-points.

        The IBZ k-points are split into contiguous chunks: ranks below
        ``rank0`` get ``nibzkpts // comm.size`` k-points and ranks from
        ``rank0`` on get one more.
        """

        # Number of ranks that hold one extra k-point:
        nextra = self.nibzkpts % comm.size
        self.rank0 = comm.size - nextra
        self.comm = comm

        # My number and offset of IBZ k-points
        self.mynk = self.get_count()
        self.k0 = self.get_offset()

        self.ibzk_qc = self.ibzk_kc[self.k0:self.k0 + self.mynk]
        self.weight_q = self.weight_k[self.k0:self.k0 + self.mynk]

    def copy(self, comm=mpi.serial_comm):
        """Create a copy with shared symmetry object.

        The symmetry-reduction arrays are shared (not copied) with this
        descriptor; the copy is distributed over *comm*.
        """
        kd = KPointDescriptor(self.bzk_kc, self.nspins)
        kd.weight_k = self.weight_k
        kd.ibzk_kc = self.ibzk_kc
        kd.sym_k = self.sym_k
        kd.time_reversal_k = self.time_reversal_k
        kd.bz2ibz_k = self.bz2ibz_k
        kd.ibz2bz_k = self.ibz2bz_k
        kd.bz2bz_ks = self.bz2bz_ks
        kd.symmetry = self.symmetry
        kd.nibzkpts = self.nibzkpts
        kd.set_communicator(comm)
        return kd

    def create_k_points(self, sdisp_cd, collinear):
        """Return a list of KPoints.

        One entry per IBZ k-point on this rank; each entry is a list of
        KPoint objects, one per spin when *collinear* is true, otherwise a
        single KPoint with spin None and half the weight.

        sdisp_cd: ndarray
            Multiplied by the k-point to form the Bloch phases
            ``exp(2 pi i k . sdisp)``.  Assumed to have shape (3, 2), like
            the all-ones phases used for Gamma-point calculations -- TODO
            confirm.
        collinear: bool
            Whether spins are treated separately.
        """

        kpt_qs = []

        for k in range(self.k0, self.k0 + self.mynk):
            q = k - self.k0
            weightk = self.weight_k[k]
            weight = weightk * 2 / self.nspins
            if self.gamma:
                phase_cd = np.ones((3, 2), complex)
            else:
                phase_cd = np.exp(2j * np.pi *
                                  sdisp_cd * self.ibzk_kc[k, :, np.newaxis])
            if collinear:
                spins = range(self.nspins)
            else:
                spins = [None]
                weight *= 0.5
            kpt_qs.append([KPoint(weightk, weight, s, k, q, phase_cd)
                           for s in spins])

        return kpt_qs

    def collect(self, a_ux, broadcast: bool):
        """Collect k-point-distributed data on rank 0 (or on all ranks).

        a_ux: ndarray
            Local data.  The first axis runs over this rank's (k-point,
            spin) pairs, k-point major, i.e. ``mynk * nspins`` entries.
            Any number of further axes is allowed.
        broadcast: bool
            If true, every rank receives the result; otherwise only rank 0.

        Returns an array of shape ``(nspins, nibzkpts) + a_ux.shape[1:]``
        on rank 0 (and on all ranks if *broadcast* is true), else None.
        """

        xshape = a_ux.shape[1:]
        a_qsx = a_ux.reshape((-1, self.nspins) + xshape)
        if self.comm.rank == 0 or broadcast:
            a_ksx = np.empty((self.nibzkpts, self.nspins) + xshape,
                             a_ux.dtype)

        if self.comm.rank > 0:
            self.comm.send(a_qsx, 0)
        else:
            k1 = self.get_count(0)
            a_ksx[0:k1] = a_qsx
            requests = []
            for rank in range(1, self.comm.size):
                k2 = k1 + self.get_count(rank)
                requests.append(self.comm.receive(a_ksx[k1:k2], rank,
                                                  block=False))
                k1 = k2
            assert k1 == self.nibzkpts
            self.comm.waitall(requests)

        if broadcast:
            self.comm.broadcast(a_ksx, 0)

        if self.comm.rank == 0 or broadcast:
            # Swap only the k-point and spin axes so that any number of
            # trailing axes is supported (transpose((1, 0, 2)) required
            # exactly one).
            return a_ksx.swapaxes(0, 1)
        return None

    def transform_wave_function(self, psit_G, k, index_G=None, phase_G=None):
        """Transform wave function from IBZ to BZ.

        k is the index of the desired k-point in the full BZ.  psit_G is
        presumably the wave function at the IBZ k-point ``bz2ibz_k[k]``.

        If *index_G* and *phase_G* (see
        ``get_transform_wavefunction_index()``) are given, they are used
        instead of applying the symmetry operation directly.  The result
        is complex conjugated when ``time_reversal_k[k]`` is set.
        """

        s = self.sym_k[k]
        time_reversal = self.time_reversal_k[k]
        op_cc = np.linalg.inv(self.symmetry.op_scc[s]).round().astype(int)

        # Identity
        if (np.abs(op_cc - np.eye(3, dtype=int)) < 1e-10).all():
            if time_reversal:
                return psit_G.conj()
            else:
                return psit_G
        # General point group symmetry
        else:
            ik = self.bz2ibz_k[k]
            kibz_c = self.ibzk_kc[ik]
            b_g = np.zeros_like(psit_G)
            kbz_c = np.dot(self.symmetry.op_scc[s], kibz_c)
            if index_G is not None:
                assert index_G.shape == psit_G.shape == phase_G.shape
                cgpaw.symmetrize_with_index(psit_G, b_g, index_G, phase_G)
            else:
                cgpaw.symmetrize_wavefunction(psit_G, b_g, op_cc.copy(),
                                              np.ascontiguousarray(kibz_c),
                                              kbz_c)

            if time_reversal:
                return b_g.conj()
            else:
                return b_g

    def get_transform_wavefunction_index(self, nG, k):
        """Get the "wavefunction transform index".

        This is a permutation of the numbers 1, 2, .. N which
        associates k + q to some k, and where N is the total
        number of grid points as specified by nG which is a
        3D tuple.

        Returns index_G and phase_G, arrays of shape nG.  For the identity
        operation index_G is the trivial permutation and phase_G is all
        ones (real); otherwise both are filled by
        ``cgpaw.symmetrize_return_index()`` and phase_G is complex."""

        s = self.sym_k[k]
        op_cc = np.linalg.inv(self.symmetry.op_scc[s]).round().astype(int)

        # Identity
        if (np.abs(op_cc - np.eye(3, dtype=int)) < 1e-10).all():
            nG0 = np.prod(nG)
            index_G = np.arange(nG0).reshape(nG)
            phase_G = np.ones(nG)
        # General point group symmetry
        else:
            ik = self.bz2ibz_k[k]
            kibz_c = self.ibzk_kc[ik]
            index_G = np.zeros(nG, dtype=int)
            phase_G = np.zeros(nG, dtype=complex)

            kbz_c = np.dot(self.symmetry.op_scc[s], kibz_c)
            cgpaw.symmetrize_return_index(index_G, phase_G, op_cc.copy(),
                                          np.ascontiguousarray(kibz_c),
                                          kbz_c)
        return index_G, phase_G

    def find_k_plus_q(self, q_c,
                      kpts_k: Optional[Sequence[int]] = None) -> list[int]:
        """Find the indices of k+q for all kpoints in the Brillouin zone.

        In case that k+q is outside the BZ, the k-point inside the BZ
        corresponding to k+q is given.

        Parameters
        ----------
        q_c: np.ndarray
            Coordinates for the q-vector in units of the reciprocal
            lattice vectors.
        kpts_k:
            Restrict search to specified k-points (BZ indices).  Default:
            all k-points in the BZ.

        Raises KPointError if some k+q is not in the k-point set.
        """
        if kpts_k is None:
            kpts_k = range(self.nbzkpts)

        i_x = []
        for k in kpts_k:
            kpt_c = self.bzk_kc[k] + q_c
            # Distance to every BZ k-point modulo reciprocal lattice vectors.
            d_kc = kpt_c - self.bzk_kc
            d_k = abs(d_kc - d_kc.round()).sum(1)
            i = d_k.argmin()
            if d_k[i] > 1e-8:
                raise KPointError('Could not find k+q!')
            i_x.append(i)

        return i_x

    def get_bz_q_points(self, first=False):
        """Return the q=k1-k2. q-mesh is always Gamma-centered.

        Requires a Monkhorst-Pack grid (``N_c``).  If *first* is true, the
        q-points are wrapped to the first Brillouin zone using the cell of
        ``self.symmetry``.
        """
        # Shift even directions by half a spacing so the mesh contains
        # Gamma.
        shift_c = 0.5 * ((self.N_c + 1) % 2) / self.N_c
        bzq_qc = monkhorst_pack(self.N_c) + shift_c
        if first:
            return to1bz(bzq_qc, self.symmetry.cell_cv)
        else:
            return bzq_qc

    def get_ibz_q_points(self, bzq_qc, op_scc):
        """Return ibz q points and the corresponding symmetry operations that
        work for k-mesh as well.

        A q-point is mapped onto an already found IBZ q-point only if the
        symmetry operation (with the same time-reversal flag) also occurs
        in ``sym_k``/``time_reversal_k``; otherwise it becomes a new IBZ
        q-point.  As a side effect ``self.q_weights`` is set to the
        weights of the IBZ q-points.

        Returns ``(ibzq_qc, ibzq_q, iop_q, timerev_q, diff_qc)``: the IBZ
        q-points, the IBZ index of each BZ q-point, and dicts (keyed by BZ
        q-point index) of symmetry-operation index, time-reversal flag and
        reciprocal lattice vector difference.
        """

        ibzq_qc_tmp = [bzq_qc[-1]]
        weight_tmp = [0]

        # NOTE(review): identity_iop stays unbound if op_scc contains no
        # identity; it is only needed when a new IBZ q-point is created.
        for i, op_cc in enumerate(op_scc):
            if np.abs(op_cc - np.eye(3)).sum() < 1e-8:
                identity_iop = i
                break

        ibzq_q_tmp = {}
        iop_q = {}
        timerev_q = {}
        diff_qc = {}

        for i in range(len(bzq_qc) - 1, -1, -1):  # loop opposite to kpoint
            try:
                ibzk, iop, timerev, diff_c = self.find_ibzkpt(
                    op_scc, ibzq_qc_tmp, bzq_qc[i])
                # The operation must also be one used for the k-mesh;
                # otherwise fall through to the except-branch below.
                if not any(iop1 == iop and timerev1 == timerev
                           for iop1, timerev1 in zip(self.sym_k,
                                                      self.time_reversal_k)):
                    raise ValueError('cant find k!')

                ibzq_q_tmp[i] = ibzk
                weight_tmp[ibzk] += 1.
                iop_q[i] = iop
                timerev_q[i] = timerev
                diff_qc[i] = diff_c
            except ValueError:
                # No usable symmetry: this q-point is a new IBZ q-point.
                ibzq_qc_tmp.append(bzq_qc[i])
                weight_tmp.append(1.)
                ibzq_q_tmp[i] = len(ibzq_qc_tmp) - 1
                iop_q[i] = identity_iop
                timerev_q[i] = False
                diff_qc[i] = np.zeros(3)

        # reverse the order.
        nq = len(ibzq_qc_tmp)
        ibzq_qc = np.zeros((nq, 3))
        ibzq_q = np.zeros(len(bzq_qc), dtype=int)
        for i in range(nq):
            ibzq_qc[i] = ibzq_qc_tmp[nq - i - 1]
        for i in range(len(bzq_qc)):
            ibzq_q[i] = nq - ibzq_q_tmp[i] - 1
        self.q_weights = np.array(weight_tmp[::-1]) / len(bzq_qc)
        return ibzq_qc, ibzq_q, iop_q, timerev_q, diff_qc

    def find_ibzkpt(self, symrel, ibzk_kc, bzk_c):
        """Find index in IBZ and related symmetry operations.

        Searches, without and then with time reversal (sign -1), for the
        first symmetry operation in *symrel* and IBZ point in *ibzk_kc*
        that map onto *bzk_c* modulo a reciprocal lattice vector.

        Returns ``(ibzkpt, iop, timerev, diff_c)`` where diff_c is the
        (rounded) reciprocal lattice vector difference.

        Raises ValueError if no match is found.
        """
        for sign in (1, -1):
            for iop, op in enumerate(symrel):
                for ibzkpt, ibzk in enumerate(ibzk_kc):
                    diff_c = bzk_c - sign * np.dot(op, ibzk)
                    if (np.abs(diff_c - diff_c.round()) < 1e-8).all():
                        return ibzkpt, iop, sign == -1, diff_c.round()

        raise ValueError('Cant find corresponding IBZ kpoint!')

    def where_is_q(self, q_c, bzq_qc):
        """Find the index of q points in BZ.

        Matching is done modulo reciprocal lattice vectors.  Raises
        KPointError if *q_c* is not in *bzq_qc*.
        """
        d_qc = q_c - bzq_qc
        d_q = abs(d_qc - d_qc.round()).sum(1)
        q = d_q.argmin()
        if d_q[q] > 1e-8:
            raise KPointError('Could not find q!')
        return q

    def get_count(self, rank=None):
        """Return the number of IBZ k-points which belong to a given rank.

        Defaults to this rank.
        """

        if rank is None:
            rank = self.comm.rank
        assert rank in range(self.comm.size)
        mynk0 = self.nibzkpts // self.comm.size
        mynk = mynk0
        if rank >= self.rank0:
            mynk += 1
        return mynk

    def get_offset(self, rank=None):
        """Return the global index of the first IBZ k-point on a given rank.

        Defaults to this rank.
        """

        if rank is None:
            rank = self.comm.rank
        assert rank in range(self.comm.size)
        mynk0 = self.nibzkpts // self.comm.size
        k0 = rank * mynk0
        if rank >= self.rank0:
            # Each rank from rank0 on holds one extra k-point.
            k0 += rank - self.rank0
        return k0

    def get_rank_and_index(self, k):
        """Find rank and local index of IBZ k-point k (see who_has())."""

        return self.who_has(k)

    def get_indices(self, rank=None):
        """Return the global IBZ k-point indices which belong to a rank."""

        k1 = self.get_offset(rank)
        k2 = k1 + self.get_count(rank)
        return np.arange(k1, k2)

    def who_has(self, k):
        """Convert global IBZ k-point index to (rank, local index)."""

        mynk0 = self.nibzkpts // self.comm.size
        if k < mynk0 * self.rank0:
            # Among the ranks holding mynk0 k-points each.
            rank, q = divmod(k, mynk0)
        else:
            # Among the ranks holding mynk0 + 1 k-points each.
            rank, q = divmod(k - mynk0 * self.rank0, mynk0 + 1)
            rank += self.rank0
        return rank, q

    def write(self, writer):
        """Write k-point and symmetry information with *writer*."""
        writer.write('ibzkpts', self.ibzk_kc)
        writer.write('bzkpts', self.bzk_kc)
        writer.write('bz2ibz', self.bz2ibz_k)
        writer.write('weights', self.weight_k)
        writer.write('rotations', self.symmetry.op_scc)
        writer.write('translations', self.symmetry.ft_sc)
        writer.write('atommap', self.symmetry.a_sa)