Coverage for gpaw/wannier/overlaps.py: 99%
165 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from ase import Atoms
from ase.units import Bohr

from gpaw.new.ase_interface import ASECalculator as GPAW
from gpaw.kpt_descriptor import KPointDescriptor
from gpaw.projections import Projections
from gpaw.setup import Setup
from gpaw.typing import Array2D, Array3D, Array4D, ArrayLike1D
from gpaw.utilities.partition import AtomPartition

from .functions import WannierFunctions
18class WannierOverlaps:
19 def __init__(self,
20 atoms: Atoms,
21 nwannier: int,
22 monkhorst_pack_size: ArrayLike1D,
23 kpoints: Array2D,
24 fermi_level: float,
25 directions: Dict[Tuple[int, ...], int],
26 overlaps: Array4D,
27 projections: Array3D = None,
28 proj_indices_a: List[List[int]] = None):
30 self.atoms = atoms
31 self.nwannier = nwannier
32 self.monkhorst_pack_size = np.array(monkhorst_pack_size)
33 self.kpoints = kpoints
34 self.fermi_level = fermi_level
35 self.directions = directions
37 self.nkpts, ndirs, self.nbands, nbands = overlaps.shape
38 assert nbands == self.nbands
39 assert self.nkpts == np.prod(monkhorst_pack_size) # type: ignore
40 assert ndirs == len(directions)
42 self._overlaps = overlaps
43 self.projections = projections
44 self.proj_indices_a = proj_indices_a
46 def overlap(self,
47 bz_index: int,
48 direction: Tuple[int, ...]) -> Array2D:
49 dindex = self.directions.get(direction)
50 if dindex is not None:
51 return self._overlaps[bz_index, dindex]
53 size = self.monkhorst_pack_size
54 i_c = np.unravel_index(bz_index, size)
55 i2_c = np.array(i_c) + direction
56 bz_index2 = np.ravel_multi_index(i2_c, size, 'wrap') # type: ignore
57 direction2 = tuple([-d for d in direction])
58 dindex2 = self.directions[direction2]
59 return self._overlaps[bz_index2, dindex2].T.conj()
61 def localize_er(self,
62 maxiter: int = 100,
63 tolerance: float = 1e-5,
64 verbose: bool = not False) -> WannierFunctions:
65 from .edmiston_ruedenberg import localize
66 return localize(self, maxiter, tolerance, verbose)
68 def localize_w90(self,
69 prefix: str = 'wannier',
70 folder: Union[Path, str] = 'W90',
71 nwannier: int = None,
72 **kwargs) -> WannierFunctions:
73 from .w90 import Wannier90
74 w90 = Wannier90(prefix, folder)
75 w90.write_input_files(overlaps=self, **kwargs)
76 w90.run_wannier90()
77 return w90.read_result()
def dict_to_proj_indices(dct: Dict[Union[int, str], str],
                         setups: List[Setup]) -> List[List[int]]:
    """Convert dict to lists of projector function indices.

    ``dct`` maps an atom index or a chemical symbol to a string of
    angular-momentum letters (``'s'``, ``'p'``, ``'d'``, ``'f'``).  A key
    equal to the atom index wins over a key equal to the atom's symbol;
    atoms matched by neither get no projectors.

    Every projector channel ``j`` of a setup occupies ``2 * l_j + 1``
    consecutive indices.  A channel is selected only if it is bound
    (``n_j > 0``) and its letter occurs in the string.

    >>> from gpaw.setup import create_setup
    >>> setup = create_setup('Si')  # 3s, 3p, *s, *p, *d
    >>> setup.n_j
    [3, 3, -1, -1, -1]
    >>> setup.l_j
    [0, 1, 0, 1, 2]
    >>> dict_to_proj_indices({'Si': 'sp', 1: 's'}, [setup, setup])
    [[0, 1, 2, 3], [0]]
    """
    proj_indices: List[List[int]] = []
    for atom_index, setup in enumerate(setups):
        if atom_index in dct:
            letters = dct[atom_index]
        else:
            letters = dct.get(setup.symbol, '')

        selected: List[int] = []
        start = 0
        for n, l in zip(setup.n_j, setup.l_j):
            width = 2 * l + 1  # number of m-components in this channel
            bound = n > 0
            if bound and 'spdf'[l] in letters:
                selected.extend(range(start, start + width))
            start += width
        proj_indices.append(selected)
    return proj_indices
def calculate_overlaps(calc: GPAW,
                       nwannier: int,
                       projections: Dict[Union[int, str], str] = None,
                       n1: int = 0,
                       n2: int = 0,
                       spinors: bool = False,
                       spin: int = 0) -> WannierOverlaps:
    """Create WannierOverlaps object from DFT calculation.

    calc:
        GPAW calculator holding the wave functions.
    nwannier:
        Number of Wannier functions.  When ``projections`` is given it must
        equal the total number of selected projector functions.
    projections:
        Dict mapping atom index or chemical symbol to angular-momentum
        letters (see :func:`dict_to_proj_indices`).  ``None`` selects none.
    n1, n2:
        Band range ``n1:n2``.  A non-positive ``n2`` is counted relative to
        the number of bands in ``calc`` (so ``0`` means all bands).
    spinors:
        Must be False (spinors are not supported).
    spin:
        Spin channel to use.

    Returns a WannierOverlaps whose overlaps array has shape
    ``(nbzkpts, ndirections, n2 - n1, n2 - n1)`` and whose projections
    array has shape ``(nbzkpts, nproj, n2 - n1)``.

    Raises AssertionError for ``spinors=True`` or when the number of
    selected projector functions differs from ``nwannier``.
    """
    assert not spinors

    if n2 <= 0:
        n2 += calc.get_number_of_bands()

    bzwfs = BZRealSpaceWaveFunctions.from_calculation(calc, n1, n2, spin)

    proj_indices_a = dict_to_proj_indices(projections or {},
                                          calc.setups)

    # offsets[a] is where atom a's block starts along the projection axis;
    # the final cumulative sum popped off is the total count nproj.
    offsets = [0]
    for indices in proj_indices_a:
        offsets.append(offsets[-1] + len(indices))
    nproj = offsets.pop()

    if projections is not None:
        assert nproj == nwannier

    kd = bzwfs.kd
    gd = bzwfs.gd
    size = kd.N_c
    assert size is not None

    # Neighbor directions (one of each +dk/-dk pair) -> index in Z_kdnn.
    icell = calc.atoms.cell.reciprocal()
    directions = {direction: i
                  for i, direction
                  in enumerate(find_directions(icell, size))}
    Z_kdnn = np.empty((kd.nbzkpts, len(directions), n2 - n1, n2 - n1), complex)

    spos_ac = calc.spos_ac
    # NOTE(review): setups come from calc.wfs.setups here but from
    # calc.setups above; presumably the same object — confirm.
    setups = calc.wfs.setups

    proj_kmn = np.zeros((kd.nbzkpts, nproj, n2 - n1), complex)

    for bz_index1 in range(kd.nbzkpts):
        wf1 = bzwfs[bz_index1]
        i1_c = np.unravel_index(bz_index1, size)
        for direction, d in directions.items():
            # Neighbor grid index, wrapped back into the grid.
            i2_c = np.array(i1_c) + direction
            bz_index2 = np.ravel_multi_index(i2_c,
                                             size,
                                             'wrap')  # type: ignore
            wf2 = bzwfs[bz_index2]
            # Per axis: 0 if no wrap, -1 if the index ran past size - 1,
            # +1 if it ran below 0.
            phase_c = (i2_c % size - i2_c) // size
            u2_nR = wf2.u_nR
            if phase_c.any():
                # Compensate for the wrapped neighbor with a plane-wave
                # factor (presumably exp(2*pi*i*phase_c . r) in scaled
                # coordinates — see GridDescriptor.plane_wave).
                u2_nR = u2_nR * gd.plane_wave(phase_c)
            # Local (per-rank) part only; summed over gd.comm below.
            Z_kdnn[bz_index1, d] = gd.integrate(wf1.u_nR, u2_nR,
                                                global_integral=False)

            # PAW correction conj(P1) dO P2^T for the atoms present in
            # wf1.projections on this rank, with the matching phase
            # factor at the atom's scaled position for wrapped neighbors.
            for a, P1_ni in wf1.projections.items():
                dO_ii = setups[a].dO_ii
                P2_ni = wf2.projections[a]
                Z_nn = P1_ni.conj().dot(dO_ii).dot(P2_ni.T).astype(complex)
                if phase_c.any():
                    Z_nn *= np.exp(2j * np.pi * phase_c.dot(spos_ac[a]))
                Z_kdnn[bz_index1, d] += Z_nn

        # Copy the selected projector rows of each locally present atom
        # into its slot of the projection axis.
        for a, P1_ni in wf1.projections.items():
            indices = proj_indices_a[a]
            m = offsets[a]
            proj_kmn[bz_index1, m:m + len(indices)] = P1_ni.T[indices]

    # Combine the per-rank partial results.
    gd.comm.sum(Z_kdnn)
    gd.comm.sum(proj_kmn)

    overlaps = WannierOverlaps(calc.atoms,
                               nwannier,
                               size,
                               kd.bzk_kc,
                               calc.get_fermi_level(),
                               directions,
                               Z_kdnn,
                               proj_kmn,
                               proj_indices_a)
    return overlaps
def find_directions(icell: Array2D,
                    mpsize: ArrayLike1D) -> List[Tuple[int, ...]]:
    """Find nearest neighbors k-points.

    icell:
        Reciprocal cell.
    mpsize:
        Size of Monkhorst-Pack grid.

    The candidates are the 26 grid offsets in {-1, 0, 1}^3.  An offset is
    a nearest neighbor when its k-point shares a Voronoi ridge with the
    origin.  If dk is a vector pointing at a neighbor k-point then -dk is
    not also included in the list.  Examples: for simple cubic there will
    be 3 neighbors and for FCC there will be 6.

    For a hexagonal cell you get three directions in plane and one
    out of plane:

    >>> hex = np.array([[1, 0, 0], [0.5, 3**0.5 / 2, 0], [0, 0, 1]])
    >>> dirs = find_directions(hex, (4, 4, 4))
    >>> sorted(dirs)
    [(0, 0, 1), (0, 1, 0), (1, -1, 0), (1, 0, 0)]
    """
    from scipy.spatial import Voronoi

    # Row i holds an offset in {-1, 0, 1}^3.  Offset -d sits at row 26 - i,
    # and (0, 0, 0) is the middle row.
    offsets_ic = np.indices((3, 3, 3)).reshape((3, -1)).T - 1
    origin = len(offsets_ic) // 2
    # Rows of icell divided by the grid size: k-point spacing vectors.
    spacing_cv = (icell.T / mpsize).T
    points_iv = offsets_ic.dot(spacing_cv)

    neighbors: List[Tuple[int, ...]] = []
    for p, q in Voronoi(points_iv).ridge_points:
        # Only ridges of the origin's cell count.  Keeping just the partner
        # with the larger row index picks exactly one of each +dk/-dk pair.
        if p == origin and q > origin:
            partner = q
        elif q == origin and p > origin:
            partner = p
        else:
            continue
        neighbors.append(tuple(offsets_ic[partner].tolist()))
    return neighbors
class WaveFunction:
    """Periodic wave-function parts on a grid together with their PAW
    projections, as stored per BZ k-point by BZRealSpaceWaveFunctions."""

    def __init__(self,
                 u_nR,
                 projections: Projections):
        # u_nR: grid values of the periodic parts, one per band.
        # projections: PAW projections of the same bands.
        self.u_nR = u_nR
        self.projections = projections

    def redistribute_atoms(self,
                           gd,
                           atom_partition: AtomPartition
                           ) -> 'WaveFunction':
        """Return a new WaveFunction whose projections are redistributed
        according to ``atom_partition`` and whose grid data is distributed
        with ``gd.distribute``."""
        # The projections are redistributed first, then the grid data.
        new_projections = self.projections.redist(atom_partition)
        new_u_nR = gd.distribute(self.u_nR)
        return WaveFunction(u_nR=new_u_nR, projections=new_projections)
class BZRealSpaceWaveFunctions:
    """Container for wave-functions and PAW projections (all of BZ).

    Maps a BZ k-point index to a :class:`WaveFunction` holding the periodic
    parts of the pseudo wave functions on a real-space grid and their PAW
    projections, for the band range ``n1:n2``.
    """
    def __init__(self,
                 kd: KPointDescriptor,
                 gd,
                 wfs: Dict[int, WaveFunction]):
        # kd: k-point descriptor of the calculation.
        # gd: grid descriptor the u_nR arrays are distributed on.
        # wfs: BZ k-point index -> WaveFunction.
        self.kd = kd
        self.gd = gd
        self.wfs = wfs

    def __getitem__(self, bz_index):
        """Return the WaveFunction of BZ k-point ``bz_index``."""
        return self.wfs[bz_index]

    @classmethod
    def from_calculation(cls,
                         calc: GPAW,
                         n1: int = 0,
                         n2: int = 0,
                         spin: int = 0) -> 'BZRealSpaceWaveFunctions':
        """Extract bands ``n1:n2`` of spin channel ``spin`` from ``calc``.

        ``n2`` is used as given (no counting from the end), so the caller
        must pass the actual upper band index.

        Raises NotImplementedError if a BZ k-point is not the
        representative of its IBZ k-point (i.e. symmetry-reduced k-point
        sets are not supported).

        NOTE(review): this constructs ``BZRealSpaceWaveFunctions``
        directly rather than ``cls``, so subclasses get the base class.
        """
        wfs = calc.wfs
        kd = wfs.kd

        # LCAO calculations may not have positions set yet (e.g. after a
        # restart — presumably); they are needed for the projections.
        if wfs.mode == 'lcao' and not wfs.positions_set:
            calc.initialize_positions()

        # Grid descriptor spanning all ranks of calc.world.
        gd = wfs.gd.new_descriptor(comm=calc.world)

        nproj_a = wfs.kpt_qs[0][0].projections.nproj_a
        # All atoms on rank-0:
        rank_a = np.zeros_like(nproj_a)
        atom_partition = AtomPartition(gd.comm, rank_a)

        # Final layout: atoms assigned round-robin over the ranks.
        rank_a = np.arange(len(rank_a)) % gd.comm.size
        atom_partition2 = AtomPartition(gd.comm, rank_a)

        # Full-size buffer, refilled for every IBZ k-point (copied below).
        u_nR = gd.empty((n2 - n1), complex, global_array=True)

        bzwfs = {}
        for ibz_index in range(kd.nibzkpts):
            for n in range(n1, n2):
                # periodic=True: the periodic part u rather than psi.
                # The Bohr**1.5 factor presumably converts from ASE units
                # (Angstrom**-1.5) to atomic units — confirm.
                u_nR[n - n1] = calc.get_pseudo_wave_function(
                    band=n,
                    kpt=ibz_index,
                    spin=spin,
                    periodic=True,
                    pad=False) * Bohr**1.5
            # May be None on some ranks (handled below); the data is
            # attached with the all-on-rank-0 atom partition.
            P_nI = wfs.collect_projections(ibz_index, spin)
            if P_nI is not None:
                P_nI = P_nI[n1:n2]
            projections = Projections(
                nbands=n2 - n1,
                nproj_a=nproj_a,
                atom_partition=atom_partition,
                data=P_nI)

            wf = WaveFunction(u_nR.copy(), projections)

            # Store wf for every BZ point mapping to this IBZ point; only
            # the IBZ point's own representative is supported.
            for bz_index, ibz_index2 in enumerate(kd.bz2ibz_k):
                if ibz_index2 != ibz_index:
                    continue
                if kd.ibz2bz_k[ibz_index] == bz_index:
                    wf1 = wf
                else:
                    # One could potentially use the IBZ2BZMap functionality
                    # to transform the wave function in the future
                    raise NotImplementedError()

                # Spread grid data and projections over all ranks.
                bzwfs[bz_index] = wf1.redistribute_atoms(gd, atom_partition2)

        return BZRealSpaceWaveFunctions(kd, gd, bzwfs)