Coverage for gpaw/wannier/overlaps.py: 99%

165 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from pathlib import Path 

2from typing import Dict, List, Tuple, Union 

3 

4import numpy as np 

5from ase import Atoms 

6from ase.units import Bohr 

7 

8from gpaw.new.ase_interface import ASECalculator as GPAW 

9from gpaw.kpt_descriptor import KPointDescriptor 

10from gpaw.projections import Projections 

11from gpaw.setup import Setup 

12from gpaw.typing import Array2D, Array3D, Array4D, ArrayLike1D 

13from gpaw.utilities.partition import AtomPartition 

14 

15from .functions import WannierFunctions 

16 

17 

18class WannierOverlaps: 

19 def __init__(self, 

20 atoms: Atoms, 

21 nwannier: int, 

22 monkhorst_pack_size: ArrayLike1D, 

23 kpoints: Array2D, 

24 fermi_level: float, 

25 directions: Dict[Tuple[int, ...], int], 

26 overlaps: Array4D, 

27 projections: Array3D = None, 

28 proj_indices_a: List[List[int]] = None): 

29 

30 self.atoms = atoms 

31 self.nwannier = nwannier 

32 self.monkhorst_pack_size = np.array(monkhorst_pack_size) 

33 self.kpoints = kpoints 

34 self.fermi_level = fermi_level 

35 self.directions = directions 

36 

37 self.nkpts, ndirs, self.nbands, nbands = overlaps.shape 

38 assert nbands == self.nbands 

39 assert self.nkpts == np.prod(monkhorst_pack_size) # type: ignore 

40 assert ndirs == len(directions) 

41 

42 self._overlaps = overlaps 

43 self.projections = projections 

44 self.proj_indices_a = proj_indices_a 

45 

46 def overlap(self, 

47 bz_index: int, 

48 direction: Tuple[int, ...]) -> Array2D: 

49 dindex = self.directions.get(direction) 

50 if dindex is not None: 

51 return self._overlaps[bz_index, dindex] 

52 

53 size = self.monkhorst_pack_size 

54 i_c = np.unravel_index(bz_index, size) 

55 i2_c = np.array(i_c) + direction 

56 bz_index2 = np.ravel_multi_index(i2_c, size, 'wrap') # type: ignore 

57 direction2 = tuple([-d for d in direction]) 

58 dindex2 = self.directions[direction2] 

59 return self._overlaps[bz_index2, dindex2].T.conj() 

60 

61 def localize_er(self, 

62 maxiter: int = 100, 

63 tolerance: float = 1e-5, 

64 verbose: bool = not False) -> WannierFunctions: 

65 from .edmiston_ruedenberg import localize 

66 return localize(self, maxiter, tolerance, verbose) 

67 

68 def localize_w90(self, 

69 prefix: str = 'wannier', 

70 folder: Union[Path, str] = 'W90', 

71 nwannier: int = None, 

72 **kwargs) -> WannierFunctions: 

73 from .w90 import Wannier90 

74 w90 = Wannier90(prefix, folder) 

75 w90.write_input_files(overlaps=self, **kwargs) 

76 w90.run_wannier90() 

77 return w90.read_result() 

78 

79 

80def dict_to_proj_indices(dct: Dict[Union[int, str], str], 

81 setups: List[Setup]) -> List[List[int]]: 

82 """Convert dict to lists of projector function indices. 

83 

84 >>> from gpaw.setup import create_setup 

85 >>> setup = create_setup('Si') # 3s, 3p, *s, *p, *d 

86 >>> setup.n_j 

87 [3, 3, -1, -1, -1] 

88 >>> setup.l_j 

89 [0, 1, 0, 1, 2] 

90 >>> dict_to_proj_indices({'Si': 'sp', 1: 's'}, [setup, setup]) 

91 [[0, 1, 2, 3], [0]] 

92 """ 

93 indices_a = [] 

94 for a, setup in enumerate(setups): 

95 ll = dct.get(a, dct.get(setup.symbol, '')) 

96 indices = [] 

97 i = 0 

98 for n, l in zip(setup.n_j, setup.l_j): 

99 if n > 0 and 'spdf'[l] in ll: 

100 indices += list(range(i, i + 2 * l + 1)) 

101 i += 2 * l + 1 

102 indices_a.append(indices) 

103 return indices_a 

104 

105 

106def calculate_overlaps(calc: GPAW, 

107 nwannier: int, 

108 projections: Dict[Union[int, str], str] = None, 

109 n1: int = 0, 

110 n2: int = 0, 

111 spinors: bool = False, 

112 spin: int = 0) -> WannierOverlaps: 

113 """Create WannierOverlaps object from DFT calculation. 

114 """ 

115 assert not spinors 

116 

117 if n2 <= 0: 

118 n2 += calc.get_number_of_bands() 

119 

120 bzwfs = BZRealSpaceWaveFunctions.from_calculation(calc, n1, n2, spin) 

121 

122 proj_indices_a = dict_to_proj_indices(projections or {}, 

123 calc.setups) 

124 

125 offsets = [0] 

126 for indices in proj_indices_a: 

127 offsets.append(offsets[-1] + len(indices)) 

128 nproj = offsets.pop() 

129 

130 if projections is not None: 

131 assert nproj == nwannier 

132 

133 kd = bzwfs.kd 

134 gd = bzwfs.gd 

135 size = kd.N_c 

136 assert size is not None 

137 

138 icell = calc.atoms.cell.reciprocal() 

139 directions = {direction: i 

140 for i, direction 

141 in enumerate(find_directions(icell, size))} 

142 Z_kdnn = np.empty((kd.nbzkpts, len(directions), n2 - n1, n2 - n1), complex) 

143 

144 spos_ac = calc.spos_ac 

145 setups = calc.wfs.setups 

146 

147 proj_kmn = np.zeros((kd.nbzkpts, nproj, n2 - n1), complex) 

148 

149 for bz_index1 in range(kd.nbzkpts): 

150 wf1 = bzwfs[bz_index1] 

151 i1_c = np.unravel_index(bz_index1, size) 

152 for direction, d in directions.items(): 

153 i2_c = np.array(i1_c) + direction 

154 bz_index2 = np.ravel_multi_index(i2_c, 

155 size, 

156 'wrap') # type: ignore 

157 wf2 = bzwfs[bz_index2] 

158 phase_c = (i2_c % size - i2_c) // size 

159 u2_nR = wf2.u_nR 

160 if phase_c.any(): 

161 u2_nR = u2_nR * gd.plane_wave(phase_c) 

162 Z_kdnn[bz_index1, d] = gd.integrate(wf1.u_nR, u2_nR, 

163 global_integral=False) 

164 

165 for a, P1_ni in wf1.projections.items(): 

166 dO_ii = setups[a].dO_ii 

167 P2_ni = wf2.projections[a] 

168 Z_nn = P1_ni.conj().dot(dO_ii).dot(P2_ni.T).astype(complex) 

169 if phase_c.any(): 

170 Z_nn *= np.exp(2j * np.pi * phase_c.dot(spos_ac[a])) 

171 Z_kdnn[bz_index1, d] += Z_nn 

172 

173 for a, P1_ni in wf1.projections.items(): 

174 indices = proj_indices_a[a] 

175 m = offsets[a] 

176 proj_kmn[bz_index1, m:m + len(indices)] = P1_ni.T[indices] 

177 

178 gd.comm.sum(Z_kdnn) 

179 gd.comm.sum(proj_kmn) 

180 

181 overlaps = WannierOverlaps(calc.atoms, 

182 nwannier, 

183 size, 

184 kd.bzk_kc, 

185 calc.get_fermi_level(), 

186 directions, 

187 Z_kdnn, 

188 proj_kmn, 

189 proj_indices_a) 

190 return overlaps 

191 

192 

193def find_directions(icell: Array2D, 

194 mpsize: ArrayLike1D) -> List[Tuple[int, ...]]: 

195 """Find nearest neighbors k-points. 

196 

197 icell: 

198 Reciprocal cell. 

199 mpsize: 

200 Size of Monkhorst-Pack grid. 

201 

202 If dk is a vector pointing at a neighbor k-points then we don't 

203 also include -dk in the list. Examples: for simple cubic there 

204 will be 3 neighbors and for FCC there will be 6. 

205 

206 For a hexagonal cell you get three directions in plane and one 

207 out of plane: 

208 

209 >>> hex = np.array([[1, 0, 0], [0.5, 3**0.5 / 2, 0], [0, 0, 1]]) 

210 >>> dirs = find_directions(hex, (4, 4, 4)) 

211 >>> sorted(dirs) 

212 [(0, 0, 1), (0, 1, 0), (1, -1, 0), (1, 0, 0)] 

213 """ 

214 

215 from scipy.spatial import Voronoi 

216 

217 d_ic = np.indices((3, 3, 3)).reshape((3, -1)).T - 1 

218 d_iv = d_ic.dot((icell.T / mpsize).T) 

219 voro = Voronoi(d_iv) 

220 directions: List[Tuple[int, ...]] = [] 

221 for i1, i2 in voro.ridge_points: 

222 if i1 == 13 and i2 > 13: 

223 directions.append(tuple(d_ic[i2].tolist())) 

224 elif i2 == 13 and i1 > 13: 

225 directions.append(tuple(d_ic[i1].tolist())) 

226 return directions 

227 

228 

229class WaveFunction: 

230 def __init__(self, 

231 u_nR, 

232 projections: Projections): 

233 self.u_nR = u_nR 

234 self.projections = projections 

235 

236 def redistribute_atoms(self, 

237 gd, 

238 atom_partition: AtomPartition 

239 ) -> 'WaveFunction': 

240 projections = self.projections.redist(atom_partition) 

241 u_nR = gd.distribute(self.u_nR) 

242 return WaveFunction(u_nR, projections) 

243 

244 

245class BZRealSpaceWaveFunctions: 

246 """Container for wave-functions and PAW projections (all of BZ).""" 

247 def __init__(self, 

248 kd: KPointDescriptor, 

249 gd, 

250 wfs: Dict[int, WaveFunction]): 

251 self.kd = kd 

252 self.gd = gd 

253 self.wfs = wfs 

254 

255 def __getitem__(self, bz_index): 

256 return self.wfs[bz_index] 

257 

258 @classmethod 

259 def from_calculation(cls, 

260 calc: GPAW, 

261 n1: int = 0, 

262 n2: int = 0, 

263 spin=0) -> 'BZRealSpaceWaveFunctions': 

264 wfs = calc.wfs 

265 kd = wfs.kd 

266 

267 if wfs.mode == 'lcao' and not wfs.positions_set: 

268 calc.initialize_positions() 

269 

270 gd = wfs.gd.new_descriptor(comm=calc.world) 

271 

272 nproj_a = wfs.kpt_qs[0][0].projections.nproj_a 

273 # All atoms on rank-0: 

274 rank_a = np.zeros_like(nproj_a) 

275 atom_partition = AtomPartition(gd.comm, rank_a) 

276 

277 rank_a = np.arange(len(rank_a)) % gd.comm.size 

278 atom_partition2 = AtomPartition(gd.comm, rank_a) 

279 

280 u_nR = gd.empty((n2 - n1), complex, global_array=True) 

281 

282 bzwfs = {} 

283 for ibz_index in range(kd.nibzkpts): 

284 for n in range(n1, n2): 

285 u_nR[n - n1] = calc.get_pseudo_wave_function( 

286 band=n, 

287 kpt=ibz_index, 

288 spin=spin, 

289 periodic=True, 

290 pad=False) * Bohr**1.5 

291 P_nI = wfs.collect_projections(ibz_index, spin) 

292 if P_nI is not None: 

293 P_nI = P_nI[n1:n2] 

294 projections = Projections( 

295 nbands=n2 - n1, 

296 nproj_a=nproj_a, 

297 atom_partition=atom_partition, 

298 data=P_nI) 

299 

300 wf = WaveFunction(u_nR.copy(), projections) 

301 

302 for bz_index, ibz_index2 in enumerate(kd.bz2ibz_k): 

303 if ibz_index2 != ibz_index: 

304 continue 

305 if kd.ibz2bz_k[ibz_index] == bz_index: 

306 wf1 = wf 

307 else: 

308 # One could potentially use the IBZ2BZMap functionality 

309 # to transform the wave function in the future 

310 raise NotImplementedError() 

311 

312 bzwfs[bz_index] = wf1.redistribute_atoms(gd, atom_partition2) 

313 

314 return BZRealSpaceWaveFunctions(kd, gd, bzwfs)