Coverage for gpaw/kpt_descriptor.py: 90%
326 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1# Copyright (C) 2003 CAMP
2# Please see the accompanying LICENSE file for further information.
4"""K-point descriptor."""
6from __future__ import annotations
7from typing import Optional, Sequence
9import numpy as np
10from ase.calculators.calculator import kptdensity2monkhorstpack
11from ase.dft.kpoints import get_monkhorst_pack_size_and_offset, monkhorst_pack
13import gpaw.cgpaw as cgpaw
14import gpaw.mpi as mpi
15from gpaw import KPointError
16from gpaw.typing import Array1D
17from gpaw.kpoint import KPoint
def to1bz(bzk_kc, cell_cv):
    """Wrap k-points to the first Brillouin zone.

    Every k-point is shifted by the integer reciprocal-lattice vector
    (components in {-1, 0, 1}) that lies closest to it in Cartesian space.
    A new array is returned; *bzk_kc* itself is not modified.

    bzk_kc: (n,3) ndarray
        Array of k-points in units of the reciprocal lattice vectors.
    cell_cv: (3,3) ndarray
        Unit cell.
    """
    B_cv = 2.0 * np.pi * np.linalg.inv(cell_cv).T
    K_kv = np.dot(bzk_kc, B_cv)

    # All 27 candidate shifts N = (-1..1, -1..1, -1..1) and their
    # Cartesian reciprocal-lattice vectors.
    shift_xc = np.indices((3, 3, 3)).reshape((3, 27)).T - 1
    G_xv = np.dot(shift_xc, B_cv)

    # Squared distance from every k-point to every candidate vector.
    dist_kx = ((G_xv[np.newaxis] - K_kv[:, np.newaxis]) ** 2).sum(axis=2)

    # Several vectors may be (numerically) equally close.  Rounding the
    # excess over the minimum to 6 decimals removes numerical noise, and
    # argmin then picks the candidate with the lowest index.
    excess_kx = dist_kx - dist_kx.min(axis=1, keepdims=True)
    best_k = excess_kx.round(6).argmin(axis=1)

    wrapped_kc = bzk_kc.copy()
    wrapped_kc -= shift_xc[best_k]
    return wrapped_kc
def kpts2sizeandoffsets(size=None, density=None, gamma=None, even=None,
                        atoms=None):
    """Helper function for selecting k-points.

    Use either size or density; if *size* is given, *density* is ignored.

    size: 3 ints
        Number of k-points.  Returned unchanged when given.
    density: float
        K-point density in units of k-points per Ang^-1.  Requires *atoms*.
    gamma: None or bool
        Should the Gamma-point be included? Yes / no / don't care:
        True / False / None.
    even: None or bool
        Only used together with *density*: it is passed on to
        ``kptdensity2monkhorstpack`` to choose how the grid size is
        rounded.  An explicitly given *size* is never modified.
    atoms: Atoms object
        Needed for calculating k-point density.  Its ``pbc`` flags decide
        in which directions an offset may be applied; without *atoms*
        all three directions are treated as periodic.

    Returns ``(size, offsets)`` where *offsets* is a list of three
    grid shifts in scaled (reciprocal-lattice) coordinates.

    Raises ValueError if *density* is given without *atoms*.
    """
    if size is None:
        if density is None:
            size = [1, 1, 1]
        else:
            if atoms is None:
                raise ValueError('Cannot set k-points from "density" unless '
                                 'Atoms are provided (need BZ dimensions).')
            size = kptdensity2monkhorstpack(atoms, density, even)

    # Without an Atoms object there are no pbc flags to consult, so treat
    # every direction as periodic (previously this crashed on atoms.pbc).
    pbc = [True, True, True] if atoms is None else atoms.pbc

    offsets = [0, 0, 0]
    if gamma is not None:
        for i, s in enumerate(size):
            # An odd Monkhorst-Pack grid contains Gamma, an even one does
            # not: shift by half a grid spacing whenever the parity of the
            # grid disagrees with the request.
            if pbc[i] and s % 2 != bool(gamma):
                offsets[i] = 0.5 / s

    return size, offsets
class KPointDescriptor:
    """Descriptor-class for k-points.

    Stores the k-points of the full Brillouin zone (BZ), their reduction
    to the irreducible BZ (IBZ) and the block distribution of the IBZ
    k-points over a k-point communicator.
    """

    def __init__(self, kpts, nspins: int = 1):
        """Construct descriptor object for kpoint/spin combinations (ks-pair).

        Parameters:

        kpts: None, sequence of 3 ints, or (n,3)-shaped array
            Specification of the k-point grid.  None=Gamma, list of
            ints=Monkhorst-Pack, ndarray=user specified (in units of the
            reciprocal lattice vectors).
        nspins: int
            Number of spins.

        Attributes
        =================== =================================================
        ``N_c``             Monkhorst-Pack grid size in each direction, or
                            None if the k-points do not form such a grid.
        ``offset_c``        Offset of the Monkhorst-Pack grid (None when
                            ``N_c`` is None).
        ``bzk_kc``          K-points of the full BZ (scaled coordinates).
        ``nbzkpts``         Number of k-points in the full BZ.
        ``nspins``          Number of spins in total.
        ``nibzkpts``        Number of irreducible kpoints in 1st BZ.
        ``mynk``            Number of IBZ k-points on this CPU.
        ``k0``              Global index of the first IBZ k-point on this
                            CPU.
        ``gamma``           Boolean indicator for gamma point calculation.
        ``monkhorst``       True if ``N_c`` is known (Monkhorst-Pack grid).
        ``comm``            MPI-communicator for kpoint distribution.
        ``weight_k``        Weights of the IBZ k-points.
        ``ibzk_kc``         IBZ k-points (scaled coordinates).
        ``ibzk_qc``         The IBZ k-points stored on this CPU.
        ``weight_q``        Weights of the IBZ k-points stored on this CPU.
        ``sym_k``           For each BZ k-point: index into
                            ``symmetry.op_scc`` of the operation that maps
                            its IBZ k-point onto it.
        ``time_reversal_k`` For each BZ k-point: whether time-reversal is
                            applied on top of ``sym_k``.
        ``bz2ibz_k``        IBZ index for each BZ k-point.
        ``ibz2bz_k``        A BZ index for each IBZ k-point.
        ``bz2bz_ks``        Presumably the BZ index of k-point k after
                            symmetry s, -1 if the image is not in the set
                            (filled by ``symmetry.reduce``; verify).
        ``refine_info``     Optional extra text for ``__str__`` (None here).
        ``symmetry``        Object representing symmetries (set by
                            ``set_symmetry``).
        =================== =================================================

        Until ``set_symmetry`` is called no symmetry is used: the IBZ is
        the full BZ and all mappings are the identity.
        """

        self.N_c: Optional[Array1D] = None
        self.offset_c: Optional[Array1D] = None

        if kpts is None:
            self.bzk_kc = np.zeros((1, 3))
            self.N_c = np.array((1, 1, 1), dtype=int)
            self.offset_c = np.zeros(3)
        else:
            kpts = np.asarray(kpts)
            if kpts.ndim == 1:
                self.N_c = np.array(kpts, dtype=int)
                self.bzk_kc = monkhorst_pack(self.N_c)
                self.offset_c = np.zeros(3)
            else:
                self.bzk_kc = np.array(kpts, dtype=float)
                try:
                    self.N_c, self.offset_c = \
                        get_monkhorst_pack_size_and_offset(self.bzk_kc)
                except ValueError:
                    # Not a Monkhorst-Pack grid: N_c/offset_c stay None.
                    pass
        self.nspins = nspins
        self.nbzkpts = len(self.bzk_kc)

        # Gamma-point calculation?
        self.gamma = self.nbzkpts == 1 and not self.bzk_kc.any()

        # Point group and time-reversal symmetry neglected:
        self.weight_k = np.ones(self.nbzkpts) / self.nbzkpts
        self.ibzk_kc = self.bzk_kc.copy()
        self.sym_k = np.zeros(self.nbzkpts, int)
        self.time_reversal_k = np.zeros(self.nbzkpts, bool)
        self.bz2ibz_k = np.arange(self.nbzkpts)
        self.ibz2bz_k = np.arange(self.nbzkpts)
        self.bz2bz_ks = np.arange(self.nbzkpts)[:, np.newaxis]
        self.nibzkpts = self.nbzkpts
        self.refine_info = None
        self.monkhorst = (self.N_c is not None)

        self.set_communicator(mpi.serial_comm)

    @staticmethod
    def _format_offset(x) -> str:
        """Format one grid offset, as ``1/n`` when it is a unit fraction."""
        if x != 0 and abs(round(1 / x) - 1 / x) < 1e-12:
            return '1/%d' % round(1 / x)
        return '%f' % x

    def __str__(self):
        """Human-readable summary: symmetry, k-point grid and IBZ points.

        Requires ``set_symmetry`` to have been called (uses
        ``self.symmetry``).
        """
        s = str(self.symmetry)

        if self.refine_info is not None:
            s += '\n' + str(self.refine_info)

        # A -1 in bz2bz_ks is reported as a k-point set that is less
        # symmetric than the crystal.
        if -1 in self.bz2bz_ks:
            s += 'Note: your k-points are not as symmetric as your crystal!\n'

        if self.gamma:
            s += '\n1 k-point (Gamma)'
        else:
            s += '\n%d k-points' % self.nbzkpts
            if self.monkhorst:
                s += ': %d x %d x %d Monkhorst-Pack grid' % tuple(self.N_c)
                if self.offset_c.any():
                    s += (' + [' +
                          ','.join(self._format_offset(x)
                                   for x in self.offset_c) +
                          ']')

        plural = 's' if self.nibzkpts > 1 else ''
        s += ('\n%d k-point%s in the irreducible part of the Brillouin zone\n'
              % (self.nibzkpts, plural))

        if self.monkhorst:
            # On a Monkhorst-Pack grid each IBZ weight is an integer
            # multiplicity divided by the number of BZ k-points.
            w_k = self.weight_k * self.nbzkpts
            assert np.allclose(w_k, w_k.round())
            w_k = w_k.round()

        s += ' k-points in crystal coordinates weights\n'
        for k in range(self.nibzkpts):
            # Print the first ten points and the last one.
            if k < 10 or k == self.nibzkpts - 1:
                if self.monkhorst:
                    s += ('%4d: %12.8f %12.8f %12.8f %6d/%d\n' %
                          ((k,) + tuple(self.ibzk_kc[k]) +
                           (w_k[k], self.nbzkpts)))
                else:
                    s += ('%4d: %12.8f %12.8f %12.8f %12.8f\n' %
                          ((k,) + tuple(self.ibzk_kc[k]) +
                           (self.weight_k[k],)))
            elif k == 10:
                s += ' ...\n'
        return s

    def set_symmetry(self, atoms, symmetry, comm=None):
        """Create symmetry object and construct irreducible Brillouin zone.

        atoms: Atoms object
            Defines atom positions and types and also unit cell and
            boundary conditions.
        symmetry: Symmetry object
            Symmetry object.
        comm: communicator or None
            Passed on to ``symmetry.reduce``.

        Raises ValueError if a non-periodic direction has a non-zero
        k-point component.
        """

        self.symmetry = symmetry

        # XXX we pass the whole atoms object just to complain if its PBCs
        # are not how we like them
        for c, periodic in enumerate(atoms.pbc):
            if not periodic and not np.allclose(self.bzk_kc[:, c], 0.0):
                raise ValueError('K-points can only be used with PBCs!')

        if symmetry.time_reversal or symmetry.point_group:
            (self.ibzk_kc, self.weight_k,
             self.sym_k,
             self.time_reversal_k,
             self.bz2ibz_k,
             self.ibz2bz_k,
             self.bz2bz_ks) = symmetry.reduce(self.bzk_kc, comm)

        # Number of irreducible k-points.
        self.nibzkpts = len(self.ibzk_kc)

    def set_communicator(self, comm):
        """Set k-point communicator.

        The IBZ k-points are distributed in contiguous blocks: ranks
        below ``self.rank0`` get ``nibzkpts // comm.size`` k-points and
        ranks from ``self.rank0`` on get one more.
        """
        self.rank0 = comm.size - self.nibzkpts % comm.size
        self.comm = comm

        # My number and offset of IBZ k-points.
        self.mynk = self.get_count()
        self.k0 = self.get_offset()

        self.ibzk_qc = self.ibzk_kc[self.k0:self.k0 + self.mynk]
        self.weight_q = self.weight_k[self.k0:self.k0 + self.mynk]

    def copy(self, comm=None):
        """Create a copy with shared symmetry object.

        The IBZ arrays and the symmetry object are shared (not copied).
        *comm* is the k-point communicator of the copy; None means
        ``mpi.serial_comm``.
        """
        if comm is None:
            comm = mpi.serial_comm
        kd = KPointDescriptor(self.bzk_kc, self.nspins)
        kd.weight_k = self.weight_k
        kd.ibzk_kc = self.ibzk_kc
        kd.sym_k = self.sym_k
        kd.time_reversal_k = self.time_reversal_k
        kd.bz2ibz_k = self.bz2ibz_k
        kd.ibz2bz_k = self.ibz2bz_k
        kd.bz2bz_ks = self.bz2bz_ks
        kd.symmetry = self.symmetry
        kd.nibzkpts = self.nibzkpts
        kd.set_communicator(comm)
        return kd

    def create_k_points(self, sdisp_cd, collinear):
        """Return a list of KPoints.

        One inner list per IBZ k-point on this CPU.  With *collinear* it
        holds one KPoint per spin; otherwise a single KPoint with spin
        None and half the weight.

        sdisp_cd: ndarray
            Multiplied with the k-point to build the phases
            ``exp(2 pi i sdisp_cd k_c)``; presumably shape (3, 2) like the
            all-ones phases used at Gamma -- confirm with callers.
        collinear: bool
            Create one KPoint per spin (True) or one spinor KPoint.
        """

        kpt_qs = []

        for k in range(self.k0, self.k0 + self.mynk):
            q = k - self.k0
            weightk = self.weight_k[k]
            weight = weightk * 2 / self.nspins
            if self.gamma:
                phase_cd = np.ones((3, 2), complex)
            else:
                phase_cd = np.exp(2j * np.pi *
                                  sdisp_cd * self.ibzk_kc[k, :, np.newaxis])
            if collinear:
                spins = range(self.nspins)
            else:
                spins = [None]
                weight *= 0.5
            kpt_qs.append([KPoint(weightk, weight, s, k, q, phase_cd)
                           for s in spins])

        return kpt_qs

    def collect(self, a_ux, broadcast: bool):
        """Collect distributed per-k-point data on rank 0 (or all ranks).

        a_ux: ndarray
            Local data whose first axis runs over (local k-point, spin)
            pairs with spin fastest; any number of trailing axes.
        broadcast: bool
            Also make the result available on all ranks.

        Returns an array of shape ``(nspins, nibzkpts) + a_ux.shape[1:]``
        on rank 0 (on every rank if *broadcast*), otherwise None.
        """

        xshape = a_ux.shape[1:]
        a_qsx = a_ux.reshape((-1, self.nspins) + xshape)
        if self.comm.rank == 0 or broadcast:
            a_ksx = np.empty((self.nibzkpts, self.nspins) + xshape,
                             a_ux.dtype)

        if self.comm.rank > 0:
            self.comm.send(a_qsx, 0)
        else:
            k1 = self.get_count(0)
            a_ksx[0:k1] = a_qsx
            requests = []
            for rank in range(1, self.comm.size):
                k2 = k1 + self.get_count(rank)
                requests.append(self.comm.receive(a_ksx[k1:k2], rank,
                                                  block=False))
                k1 = k2
            assert k1 == self.nibzkpts
            self.comm.waitall(requests)

        if broadcast:
            self.comm.broadcast(a_ksx, 0)

        if self.comm.rank == 0 or broadcast:
            # Swap only the k and spin axes; the old transpose((1, 0, 2))
            # failed unless there was exactly one trailing axis.
            return np.swapaxes(a_ksx, 0, 1)
        return None

    def transform_wave_function(self, psit_G, k, index_G=None, phase_G=None):
        """Transform wave function from IBZ to BZ.

        k is the index of the desired k-point in the full BZ.  If
        *index_G* is given, *phase_G* must be given too (same shape as
        *psit_G*); both are presumably what
        ``get_transform_wavefunction_index`` returns.
        """

        s = self.sym_k[k]
        time_reversal = self.time_reversal_k[k]
        op_cc = np.linalg.inv(self.symmetry.op_scc[s]).round().astype(int)

        # Identity: at most a complex conjugation is needed.
        if (np.abs(op_cc - np.eye(3, dtype=int)) < 1e-10).all():
            if time_reversal:
                return psit_G.conj()
            else:
                return psit_G
        # General point group symmetry
        else:
            ik = self.bz2ibz_k[k]
            kibz_c = self.ibzk_kc[ik]
            b_g = np.zeros_like(psit_G)
            kbz_c = np.dot(self.symmetry.op_scc[s], kibz_c)
            if index_G is not None:
                assert index_G.shape == psit_G.shape == phase_G.shape
                cgpaw.symmetrize_with_index(psit_G, b_g, index_G, phase_G)
            else:
                cgpaw.symmetrize_wavefunction(psit_G, b_g, op_cc.copy(),
                                              np.ascontiguousarray(kibz_c),
                                              kbz_c)

            if time_reversal:
                return b_g.conj()
            else:
                return b_g

    def get_transform_wavefunction_index(self, nG, k):
        """Get the "wavefunction transform index".

        *nG* is a 3-tuple giving the grid size and *k* a BZ k-point index.
        Returns ``(index_G, phase_G)``, both arrays of shape *nG*.  For the
        identity operation, index_G is ``arange(prod(nG))`` reshaped to
        *nG* (0-based grid-point indices) and phase_G is all ones;
        otherwise both are filled by ``cgpaw.symmetrize_return_index``.
        They can be passed to ``transform_wave_function``.
        """

        s = self.sym_k[k]
        op_cc = np.linalg.inv(self.symmetry.op_scc[s]).round().astype(int)

        # Identity
        if (np.abs(op_cc - np.eye(3, dtype=int)) < 1e-10).all():
            nG0 = np.prod(nG)
            index_G = np.arange(nG0).reshape(nG)
            phase_G = np.ones(nG)
        # General point group symmetry
        else:
            ik = self.bz2ibz_k[k]
            kibz_c = self.ibzk_kc[ik]
            index_G = np.zeros(nG, dtype=int)
            phase_G = np.zeros(nG, dtype=complex)

            kbz_c = np.dot(self.symmetry.op_scc[s], kibz_c)
            cgpaw.symmetrize_return_index(index_G, phase_G, op_cc.copy(),
                                          np.ascontiguousarray(kibz_c),
                                          kbz_c)
        return index_G, phase_G

    def find_k_plus_q(self, q_c,
                      kpts_k: Optional[Sequence[int]] = None) -> list[int]:
        """Find the indices of k+q for all kpoints in the Brillouin zone.

        In case that k+q is outside the BZ, the k-point inside the BZ
        corresponding to k+q is given.

        Parameters
        ----------
        q_c: np.ndarray
            Coordinates for the q-vector in units of the reciprocal
            lattice vectors.
        kpts_k:
            Restrict search to specified k-points (default: all BZ
            k-points).

        Raises KPointError if some k+q is not in the k-point set.
        """
        if kpts_k is None:
            kpts_k = range(self.nbzkpts)

        i_x = []
        for k in kpts_k:
            kpt_c = self.bzk_kc[k] + q_c
            # Distance modulo reciprocal lattice vectors.
            d_kc = kpt_c - self.bzk_kc
            d_k = abs(d_kc - d_kc.round()).sum(1)
            i = d_k.argmin()
            if d_k[i] > 1e-8:
                raise KPointError('Could not find k+q!')
            i_x.append(i)

        return i_x

    def get_bz_q_points(self, first=False):
        """Return the q=k1-k2.  q-mesh is always Gamma-centered.

        A Monkhorst-Pack grid of size ``N_c`` is shifted by half a grid
        spacing along directions with an even number of points, which
        makes it contain Gamma.  With *first* the points are wrapped to
        the first BZ (needs ``self.symmetry.cell_cv``).
        """
        shift_c = 0.5 * ((self.N_c + 1) % 2) / self.N_c
        bzq_qc = monkhorst_pack(self.N_c) + shift_c
        if first:
            return to1bz(bzq_qc, self.symmetry.cell_cv)
        else:
            return bzq_qc

    def get_ibz_q_points(self, bzq_qc, op_scc):
        """Return ibz q points and the corresponding symmetry operations that
        work for k-mesh as well.

        A q-point is mapped onto an already-found IBZ q-point only when
        the operation (with the same time-reversal flag) also occurs in
        ``self.sym_k``; otherwise it becomes a new IBZ q-point.

        Returns ``(ibzq_qc, ibzq_q, iop_q, timerev_q, diff_qc)``: the IBZ
        q-points, the IBZ index of each BZ q-point, and dicts (keyed by BZ
        q index) of operation index, time-reversal flag and the integer
        lattice vector from ``find_ibzkpt``.  Also sets ``self.q_weights``.

        Raises ValueError if *op_scc* contains no identity operation.
        """
        # The identity is the fallback operation for new IBZ q-points.
        identity_iop = None
        for i, op_cc in enumerate(op_scc):
            if np.abs(op_cc - np.eye(3)).sum() < 1e-8:
                identity_iop = i
                break
        if identity_iop is None:
            raise ValueError('op_scc does not contain the identity operation')

        # The lists are built from the last q-point backwards and reversed
        # at the end.  The last q-point seeds the IBZ with zero weight; the
        # loop normally maps it onto itself and adds its weight.
        ibzq_qc_tmp = [bzq_qc[-1]]
        weight_tmp = [0]

        ibzq_q_tmp = {}
        iop_q = {}
        timerev_q = {}
        diff_qc = {}

        for i in range(len(bzq_qc) - 1, -1, -1):  # loop opposite to kpoint
            try:
                ibzk, iop, timerev, diff_c = self.find_ibzkpt(
                    op_scc, ibzq_qc_tmp, bzq_qc[i])
                # Only accept operations that the k-point mesh uses too.
                if not any(iop1 == iop and self.time_reversal_k[ii] == timerev
                           for ii, iop1 in enumerate(self.sym_k)):
                    raise ValueError('cant find k!')

                ibzq_q_tmp[i] = ibzk
                weight_tmp[ibzk] += 1.
                iop_q[i] = iop
                timerev_q[i] = timerev
                diff_qc[i] = diff_c
            except ValueError:
                # Not reachable: this q-point becomes a new IBZ q-point.
                ibzq_qc_tmp.append(bzq_qc[i])
                weight_tmp.append(1.)
                ibzq_q_tmp[i] = len(ibzq_qc_tmp) - 1
                iop_q[i] = identity_iop
                timerev_q[i] = False
                diff_qc[i] = np.zeros(3)

        # reverse the order.
        nq = len(ibzq_qc_tmp)
        ibzq_qc = np.zeros((nq, 3))
        ibzq_qc[:] = ibzq_qc_tmp[::-1]
        ibzq_q = np.array([nq - ibzq_q_tmp[i] - 1
                           for i in range(len(bzq_qc))], dtype=int)
        self.q_weights = np.array(weight_tmp[::-1]) / len(bzq_qc)
        return ibzq_qc, ibzq_q, iop_q, timerev_q, diff_qc

    def find_ibzkpt(self, symrel, ibzk_kc, bzk_c):
        """Find index in IBZ and related symmetry operations.

        Searches for ``bzk_c = sign * op @ ibzk_c + G`` with an integer
        vector G, trying sign=+1 before sign=-1 (time reversal).

        Returns ``(ibz index, operation index, time-reversal flag, G)``
        for the first match.  Raises ValueError if there is none.
        """
        for sign in (1, -1):
            for iop, op_cc in enumerate(symrel):
                for ibzkpt, ibzk_c in enumerate(ibzk_kc):
                    diff_c = bzk_c - sign * np.dot(op_cc, ibzk_c)
                    if (np.abs(diff_c - diff_c.round()) < 1e-8).all():
                        return ibzkpt, iop, sign == -1, diff_c.round()

        raise ValueError('Cant find corresponding IBZ kpoint!')

    def where_is_q(self, q_c, bzq_qc):
        """Find the index of q points in BZ.

        Matching is modulo reciprocal lattice vectors.  Raises KPointError
        if *q_c* is not in *bzq_qc*.
        """
        d_qc = q_c - bzq_qc
        d_q = abs(d_qc - d_qc.round()).sum(1)
        q = d_q.argmin()
        if d_q[q] > 1e-8:
            raise KPointError('Could not find q!')
        return q

    def get_count(self, rank=None):
        """Return the number of IBZ k-points which belong to a given rank."""

        if rank is None:
            rank = self.comm.rank
        assert rank in range(self.comm.size)
        mynk0 = self.nibzkpts // self.comm.size
        mynk = mynk0
        if rank >= self.rank0:
            mynk += 1
        return mynk

    def get_offset(self, rank=None):
        """Return the global index of the first IBZ k-point on a rank."""

        if rank is None:
            rank = self.comm.rank
        assert rank in range(self.comm.size)
        mynk0 = self.nibzkpts // self.comm.size
        k0 = rank * mynk0
        if rank >= self.rank0:
            # Each earlier rank from rank0 on holds one extra k-point.
            k0 += rank - self.rank0
        return k0

    def get_rank_and_index(self, k):
        """Find rank and local index of IBZ k-point k."""

        rank, q = self.who_has(k)
        return rank, q

    def get_indices(self, rank=None):
        """Return the global IBZ k-point indices which belong to a rank."""

        k1 = self.get_offset(rank)
        k2 = k1 + self.get_count(rank)
        return np.arange(k1, k2)

    def who_has(self, k):
        """Convert global index to rank information and local index."""

        mynk0 = self.nibzkpts // self.comm.size
        if k < mynk0 * self.rank0:
            rank, q = divmod(k, mynk0)
        else:
            rank, q = divmod(k - mynk0 * self.rank0, mynk0 + 1)
            rank += self.rank0
        return rank, q

    def write(self, writer):
        """Write k-point and symmetry data with ``writer.write(name, value)``."""
        writer.write('ibzkpts', self.ibzk_kc)
        writer.write('bzkpts', self.bzk_kc)
        writer.write('bz2ibz', self.bz2ibz_k)
        writer.write('weights', self.weight_k)
        writer.write('rotations', self.symmetry.op_scc)
        writer.write('translations', self.symmetry.ft_sc)
        writer.write('atommap', self.symmetry.a_sa)