Coverage for gpaw/wannier90.py: 84%
400 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1import numpy as np
2from gpaw.utilities.blas import gemmdot
3from gpaw.ibz2bz import (get_overlap, get_overlap_coefficients,
4 get_phase_shifted_overlap_coefficients,
5 IBZ2BZMaps)
6from gpaw.spinorbit import soc_eigenstates
class Wannier90:
    """Write Wannier90 input files from a GPAW calculation.

    The ``write_*`` methods produce ``<seed>.win`` (``write_input``),
    ``<seed>.amn`` (``write_projections``), ``<seed>.eig``
    (``write_eigenvalues``), ``<seed>.mmn`` (``write_overlaps``, which
    also reads ``<seed>.nnkp``) and ``UNK*`` files
    (``write_wavefunctions``).  All methods except ``write_input`` read
    the band selection back from ``<seed>.win``, so that file must exist
    before they are called.

    Parameters
    ----------
    calc:
        GPAW calculator.
    seed: str or None
        Wannier90 seed name.  Defaults to the chemical formula of
        ``calc.atoms``.
    bands: sequence of int or None
        0-based band indices to include.  Defaults to all bands of *calc*.
    orbitals_ai: list of sequences of int or None
        For every atom, the indices of the projector orbitals used as
        projections.  Defaults to all ``2l + 1`` orbitals of every setup
        state with ``n != -1``.
    spin: int
        Spin channel used when *spinors* is False.
    spinors: bool
        If True, spin-orbit eigenstates from ``soc_eigenstates(calc)`` are
        used.  This requires ``nbzkpts == nibzkpts``, i.e. no k-point
        symmetry reduction.
    """

    def __init__(self, calc, seed=None, bands=None, orbitals_ai=None,
                 spin=0, spinors=False):
        if seed is None:
            seed = calc.atoms.get_chemical_formula()
        self.seed = seed

        if bands is None:
            bands = range(calc.get_number_of_bands())

        Na = len(calc.atoms)
        if orbitals_ai is None:
            orbitals_ai = self._default_orbitals_ai(calc.wfs.setups, Na)

        self.calc = calc
        self.ibz2bz = IBZ2BZMaps.from_calculator(calc)
        self.bands = bands
        self.Nn = len(bands)
        self.Na = Na
        self.orbitals_ai = orbitals_ai
        self.Nw = np.sum([len(orbitals_ai[ai]) for ai in range(Na)])
        self.kpts_kc = calc.get_ibz_k_points()
        self.Nk = len(self.kpts_kc)
        self.spin = spin
        self.spinors = spinors

        if spinors:
            # spinorbit.WaveFunctions.transform currently does not support
            # transformation of wavefunctions, only projections.
            # XXX: should be updated in the future
            assert calc.wfs.kd.nbzkpts == calc.wfs.kd.nibzkpts
            self.soc = soc_eigenstates(calc)
        else:
            self.soc = None

    @staticmethod
    def _num_orbitals(setup):
        """Return the number of orbitals of *setup*.

        Every state with ``n != -1`` contributes ``2l + 1`` orbitals.
        """
        return sum(2 * l + 1 for l, n in zip(setup.l_j, setup.n_j)
                   if n != -1)

    @classmethod
    def _default_orbitals_ai(cls, setups, Na):
        """Return the default ``orbitals_ai``: all orbitals of each atom.

        *setups* is indexed by atom; the result holds one ``range`` per
        atom for the first *Na* atoms.
        """
        return [range(cls._num_orbitals(setups[ia])) for ia in range(Na)]

    @staticmethod
    def _format_exclude_bands(bands):
        """Return the ``exclude_bands`` value for *bands*.

        The value lists the 1-based indices of every band below
        ``max(bands)`` that is not in *bands*, separated by commas
        (e.g. ``[2, 3, 4] -> '1,2'``).  Returns None if no band is
        missing.
        """
        present = set(bands)
        missing = ['%d' % (n + 1) for n in range(max(bands))
                   if n not in present]
        return ','.join(missing) if missing else None

    def write_input(self,
                    mp=None,
                    plot=False,
                    num_iter=100,
                    write_xyz=False,
                    write_rmn=False,
                    translate_home_cell=False,
                    dis_num_iter=200,
                    dis_froz_max=0.1,
                    dis_mix_ratio=0.5,
                    dis_win_min=None,
                    dis_win_max=None,
                    search_shells=None,
                    write_u_matrices=False):
        """Write the Wannier90 input file ``<seed>.win``.

        Parameters
        ----------
        mp: sequence of 3 int or None
            Value written as ``mp_grid``.  Defaults to ``calc.wfs.kd.N_c``.
        plot: bool
            Write ``wannier_plot = True`` and ``wvfn_formatted = True``.
        num_iter: int
            Value written as ``num_iter``.
        write_xyz, write_rmn, translate_home_cell, write_u_matrices: bool
            Enable the corresponding Wannier90 output flags.
            ``write_rmn`` also enables ``write_tb``.
        dis_num_iter, dis_mix_ratio:
            Disentanglement settings.
        dis_froz_max, dis_win_min, dis_win_max: float or None
            Energy-window edges relative to the Fermi level.  None skips
            the keyword.
        search_shells: int or None
            Value written as ``search_shells`` if given.

        The ``fermi_energy`` and ``dis_*`` keywords are written only when
        there are more bands than Wannier functions.  ``write_hr = True``
        is always written.
        """
        calc = self.calc
        seed = self.seed
        bands = self.bands
        orbitals_ai = self.orbitals_ai
        spinors = self.spinors

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        if bands is None:
            bands = range(calc.get_number_of_bands())

        Na = len(calc.atoms)
        if orbitals_ai is None:
            orbitals_ai = self._default_orbitals_ai(calc.wfs.setups, Na)
        assert len(orbitals_ai) == Na

        Nw = np.sum([len(orbitals_ai[ai]) for ai in range(Na)])
        if spinors:
            # Each spatial band n is split into the spinor states 2n, 2n+1.
            Nw *= 2
            bands = [m for n in bands for m in (2 * n, 2 * n + 1)]

        pos_ac = calc.spos_ac

        with open(seed + '.win', 'w') as f:
            print('begin projections', file=f)
            for ia, orbitals_i in enumerate(orbitals_ai):
                setup = calc.wfs.setups[ia]
                # Expand the per-state (n, l) values to one entry per orbital.
                l_i = []
                n_i = []
                for n, l in zip(setup.n_j, setup.l_j):
                    if n != -1:
                        l_i += (2 * l + 1) * [l]
                        n_i += (2 * l + 1) * [n]
                r_c = pos_ac[ia]
                for orb in orbitals_i:
                    l = l_i[orb]
                    n = n_i[orb]
                    print(f'f={r_c[0]:1.2f}, {r_c[1]:1.2f}, {r_c[2]:1.2f} : s ',
                          end='', file=f)
                    print(f'# n = {n}, l = {l}', file=f)
            print('end projections', file=f)
            print(file=f)

            if spinors:
                print('spinors = True', file=f)
            else:
                print('spinors = False', file=f)
            if write_u_matrices:
                print('write_u_matrices = True', file=f)
            print('write_hr = True', file=f)
            if write_xyz:
                print('write_xyz = True', file=f)
            if write_rmn:
                print('write_tb = True', file=f)
                print('write_rmn = True', file=f)
            if translate_home_cell:
                print('translate_home_cell = True', file=f)
            print(file=f)
            print('num_bands = %d' % len(bands), file=f)

            if search_shells is not None:
                print(f"search_shells = {search_shells}", file=f)

            exclude_bands = self._format_exclude_bands(bands)
            if exclude_bands is not None:
                print('exclude_bands : ' + exclude_bands, file=f)
            print(file=f)

            print('guiding_centres = True', file=f)
            print('num_wann = %d' % Nw, file=f)
            print('num_iter = %d' % num_iter, file=f)
            print(file=f)

            if len(bands) > Nw:
                # More bands than Wannier functions: disentanglement needed.
                ef = calc.get_fermi_level()
                print('fermi_energy = %2.3f' % ef, file=f)
                if dis_froz_max is not None:
                    print('dis_froz_max = %2.3f' % (ef + dis_froz_max),
                          file=f)
                if dis_win_min is not None:
                    print('dis_win_min = %2.3f' % (ef + dis_win_min), file=f)
                if dis_win_max is not None:
                    print('dis_win_max = %2.3f' % (ef + dis_win_max), file=f)
                print('dis_num_iter = %d' % dis_num_iter, file=f)
                print('dis_mix_ratio = %1.1f' % dis_mix_ratio, file=f)
            print(file=f)

            print('begin unit_cell_cart', file=f)
            for cell_c in calc.atoms.cell:
                print(f'{cell_c[0]:14.10f} {cell_c[1]:14.10f} '
                      f'{cell_c[2]:14.10f}', file=f)
            print('end unit_cell_cart', file=f)
            print(file=f)

            print('begin atoms_frac', file=f)
            for atom, pos_c in zip(calc.atoms, pos_ac):
                print(atom.symbol, end='', file=f)
                print(f'{pos_c[0]:14.10f} {pos_c[1]:14.10f} '
                      f'{pos_c[2]:14.10f}', file=f)
            print('end atoms_frac', file=f)
            print(file=f)

            if plot:
                print('wannier_plot = True', file=f)
                print('wvfn_formatted = True', file=f)
            print(file=f)

            if mp is not None:
                N_c = mp
            else:
                N_c = calc.wfs.kd.N_c
            print('mp_grid =', N_c[0], N_c[1], N_c[2], file=f)
            print(file=f)
            print('begin kpoints', file=f)
            for kpt in calc.get_bz_k_points():
                print(f'{kpt[0]:14.10f} {kpt[1]:14.10f} {kpt[2]:14.10f}',
                      file=f)
            print('end kpoints', file=f)

    def write_projections(self):
        """Write the initial projections A_mn(k) to ``<seed>.amn``.

        The values of ``spinors``, ``num_wann`` and ``mp_grid``, and the
        band selection, are read back from ``<seed>.win``.  The entries
        are the complex-conjugated projections of the selected bands onto
        the orbitals in ``orbitals_ai``.  When the .win file enables
        spinors, each orbital i is expanded to the two entries 2i and
        2i + 1.

        Raises
        ------
        ValueError
            If ``num_wann`` or ``mp_grid`` is missing from ``<seed>.win``.
        """
        calc = self.calc
        seed = self.seed
        spin = self.spin
        orbitals_ai = self.orbitals_ai
        soc = self.soc

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        bands = get_bands(seed)
        Nn = len(bands)

        spinors = False
        Nw = None
        Nk = None
        with open(seed + '.win') as win_file:
            for line in win_file:
                l_e = line.split()
                if not l_e:
                    continue
                if l_e[0] == 'spinors':
                    spinors = l_e[2] in ['T', 'true', '1', 'True']
                if l_e[0] == 'num_wann':
                    Nw = int(l_e[2])
                if l_e[0] == 'mp_grid':
                    Nk = int(l_e[2]) * int(l_e[3]) * int(l_e[4])
                    assert Nk == len(calc.get_bz_k_points())
        if Nw is None or Nk is None:
            raise ValueError(f'{seed}.win must define num_wann and mp_grid')

        Na = len(calc.atoms)
        if orbitals_ai is None:
            orbitals_ai = self._default_orbitals_ai(calc.wfs.setups, Na)
        assert len(orbitals_ai) == Na

        if spinors:
            # Orbital i carries two spinor components: 2i and 2i + 1.
            orbitals_ai = [[j for i in orbitals_i for j in (2 * i, 2 * i + 1)]
                           for orbitals_i in orbitals_ai]

        Ni = sum(len(orbitals_i) for orbitals_i in orbitals_ai)
        assert Nw == Ni

        with open(seed + '.amn', 'w') as f:
            print('Kohn-Sham input generated from GPAW calculation', file=f)
            print('%10d %6d %6d' % (Nn, Nk, Nw), file=f)

            P_kni = np.zeros((Nk, Nn, Nw), complex)
            for ik in range(Nk):
                if spinors:
                    P_ani = soc[ik].P_amj
                else:
                    P_ani = get_projections_in_bz(calc.wfs,
                                                  ik,
                                                  spin,
                                                  self.ibz2bz,
                                                  bcomm=None)
                # Orbital columns are laid out atom after atom.
                icount = 0
                for ai in range(Na):
                    ni = len(orbitals_ai[ai])
                    P_ni = P_ani[ai][bands]
                    P_ni = P_ni[:, orbitals_ai[ai]]
                    P_kni[ik, :, icount:ni + icount] = P_ni.conj()
                    icount += ni

            for ik in range(Nk):
                for i in range(Nw):
                    for n in range(Nn):
                        P = P_kni[ik, n, i]
                        data = (n + 1, i + 1, ik + 1, P.real, P.imag)
                        print('%4d %4d %4d %18.12f %20.12f' % data, file=f)

    def write_eigenvalues(self):
        """Write the eigenvalues of the selected bands to ``<seed>.eig``.

        There is one line per (band, BZ k-point) with 1-based indices.
        Without spin-orbit, eigenvalues are taken from the IBZ k-point
        that each BZ k-point maps to.
        """
        calc = self.calc
        seed = self.seed
        spin = self.spin
        soc = self.soc

        bands = get_bands(seed)

        with open(seed + '.eig', 'w') as f:
            for ik in range(len(calc.get_bz_k_points())):
                if soc is None:
                    ibzk = calc.wfs.kd.bz2ibz_k[ik]  # IBZ k-point
                    e_n = calc.get_eigenvalues(kpt=ibzk, spin=spin)
                else:
                    e_n = soc[ik].eig_m
                for i, n in enumerate(bands):
                    data = (i + 1, ik + 1, e_n[n])
                    print('%5d %5d %14.6f' % data, file=f)

    def write_overlaps(self, less_memory=False):
        """Write the overlap matrices M_mn(k, b) to ``<seed>.mmn``.

        The neighbour list (``begin nnkpts`` block) is read from
        ``<seed>.nnkp`` and the band selection from ``<seed>.win``.

        Parameters
        ----------
        less_memory: bool
            If True, wavefunctions are recomputed for every (k, b) pair
            instead of being stored for all k-points up front.

        Raises
        ------
        ValueError
            If ``<seed>.nnkp`` has no ``begin nnkpts`` block.
        """
        calc = self.calc
        seed = self.seed
        spin = self.spin
        soc = self.soc
        ibz2bz = self.ibz2bz

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        spinors = soc is not None

        bands = get_bands(seed)
        Nn = len(bands)
        kpts_kc = calc.get_bz_k_points()
        Nk = len(kpts_kc)

        with open(seed + '.nnkp') as nnkp:
            lines = nnkp.readlines()
        for il, line in enumerate(lines):
            words = line.split()
            if len(words) > 1 and words[0] == 'begin' and words[1] == 'nnkpts':
                # The next line holds the number of neighbours per k-point;
                # the neighbour list starts right after it.
                Nb = int(lines[il + 1].split()[0])
                i0 = il + 2
                break
        else:
            raise ValueError(f'No "begin nnkpts" block in {seed}.nnkp')

        with open(seed + '.mmn', 'w') as f:
            print('Kohn-Sham input generated from GPAW calculation', file=f)
            print('%10d %6d %6d' % (Nn, Nk, Nb), file=f)

            icell_cv = (2 * np.pi) * np.linalg.inv(calc.wfs.gd.cell_cv).T
            r_g = calc.wfs.gd.get_grid_point_coordinates()

            spos_ac = calc.spos_ac
            wfs = calc.wfs
            dO_aii = get_overlap_coefficients(wfs)

            if not less_memory:
                u_knG = [self.wavefunctions(ik, bands) for ik in range(Nk)]

            proj_k = []
            for ik in range(Nk):
                if spinors:
                    proj_k.append(soc[ik].projections)
                else:
                    proj_k.append(get_projections_in_bz(calc.wfs,
                                                        ik, spin,
                                                        ibz2bz,
                                                        bcomm=None))

            for ik1 in range(Nk):
                if less_memory:
                    u1_nG = self.wavefunctions(ik1, bands)
                else:
                    u1_nG = u_knG[ik1]
                for ib in range(Nb):
                    # b denotes nearest neighbor k-points
                    line = lines[i0 + ik1 * Nb + ib].split()
                    ik2 = int(line[1]) - 1
                    if less_memory:
                        u2_nG = self.wavefunctions(ik2, bands)
                    else:
                        u2_nG = u_knG[ik2]

                    # Neighbour may lie in a periodic image: G_c (columns
                    # 3-5 of the nnkpts line) is the reciprocal-lattice shift.
                    G_c = np.array([int(line[i]) for i in range(2, 5)])
                    bG_v = np.dot(G_c, icell_cv)
                    u2_nG = u2_nG * np.exp(-1.0j * gemmdot(bG_v, r_g,
                                                           beta=0.0))
                    bG_c = kpts_kc[ik2] - kpts_kc[ik1] + G_c
                    phase_shifted_dO_aii = \
                        get_phase_shifted_overlap_coefficients(
                            dO_aii, spos_ac, -bG_c)
                    M_mm = get_overlap(bands,
                                       wfs.gd,
                                       u1_nG,
                                       u2_nG,
                                       proj_k[ik1],
                                       proj_k[ik2],
                                       phase_shifted_dO_aii)
                    indices = (ik1 + 1, ik2 + 1, G_c[0], G_c[1], G_c[2])
                    print('%3d %3d %4d %3d %3d' % indices, file=f)
                    # First index runs fastest (column-major order).
                    for m1 in range(len(M_mm)):
                        for m2 in range(len(M_mm)):
                            M = M_mm[m2, m1]
                            print(f'{M.real:20.12f} {M.imag:20.12f}', file=f)

    def write_wavefunctions(self):
        """Write the periodic wavefunctions to ``UNKkkkkk.s`` files.

        There is one file per BZ k-point (5-digit, 1-based index), with
        ``s = spin + 1``.  Each file has a header line ``gx gy gz ik Nn``
        followed by the real and imaginary parts of every selected band,
        with x running fastest.
        """
        calc = self.calc
        soc = self.soc
        spin = self.spin
        seed = self.seed

        spinors = soc is not None

        if seed is None:
            seed = calc.atoms.get_chemical_formula()

        bands = get_bands(seed)
        Nn = len(bands)
        Nk = len(calc.get_bz_k_points())

        for ik in range(Nk):
            if spinors:
                # For spinors, G denotes spin and grid: G = (s, gx, gy, gz)
                # NOTE(review): the loops below treat axes 1-3 as the grid,
                # so for spinor arrays they index (s, gx, gy) instead of
                # (gx, gy, gz).  Confirm the intended UNK spinor layout.
                u_nG = soc[ik].wavefunctions(calc, periodic=True)
            else:
                # For non-spinors, G denotes grid: G = (gx, gy, gz)
                u_nG = self.wavefunctions(ik, bands)

            fname = 'UNK%s.%d' % (str(ik + 1).zfill(5), spin + 1)
            with open(fname, 'w') as f:
                grid_v = np.shape(u_nG)[1:]
                print(grid_v[0], grid_v[1], grid_v[2], ik + 1, Nn, file=f)
                for n in bands:
                    for iz in range(grid_v[2]):
                        for iy in range(grid_v[1]):
                            for ix in range(grid_v[0]):
                                u = u_nG[n, ix, iy, iz]
                                print(u.real, u.imag, file=f)

    def wavefunctions(self, bz_index, bands):
        """Return periodic pseudo wavefunctions at BZ k-point *bz_index*.

        For spinors, the spin-orbit wavefunctions selected by *bands* are
        returned.  Otherwise every band ``0 .. max(bands)`` of the
        corresponding IBZ k-point is mapped to the BZ, so the result is
        indexed by absolute band index.
        """
        maxband = max(bands) + 1
        if self.spinors:
            # For spinors, G denotes spin and grid: G = (s, gx, gy, gz)
            return self.soc[bz_index].wavefunctions(
                self.calc, periodic=True)[bands]
        # For non-spinors, G denotes grid: G = (gx, gy, gz)
        ibz_index = self.calc.wfs.kd.bz2ibz_k[bz_index]
        ut_nR = np.array([self.calc.wfs.get_wave_function_array(
            n, ibz_index, self.spin,
            periodic=True) for n in range(maxband)])
        bz_map = self.ibz2bz[bz_index]
        ut_nR_sym = np.array([bz_map.map_pseudo_wave_to_BZ(ut_nR[n])
                              for n in range(maxband)])

        return ut_nR_sym
474def get_bands(seed):
475 win_file = open(seed + '.win')
476 exclude_bands = None
477 for line in win_file.readlines():
478 l_e = line.split()
479 if len(l_e) > 0:
480 if l_e[0] == 'num_bands':
481 Nn = int(l_e[2])
482 if l_e[0] == 'exclude_bands':
483 exclude_bands = line.split()[2]
484 exclude_bands = [int(n) - 1 for n in exclude_bands.split(',')]
485 if exclude_bands is None:
486 bands = range(Nn)
487 else:
488 bands = range(Nn + len(exclude_bands))
489 bands = [n for n in bands if n not in exclude_bands]
490 win_file.close()
492 return bands
def get_projections_in_bz(wfs, K, s, ibz2bz, bcomm=None):
    """Return the projections at BZ k-point *K* for spin *s*.

    The projections of the IBZ k-point ``wfs.kd.bz2ibz_k[K]`` are copied
    for the first ``wfs.bd.nbands`` bands into a new projections object,
    which is then mapped to *K* with ``ibz2bz[K].map_projections``.

    Parameters
    ----------
    wfs: calc.wfs object
    K: int
        BZ k-point index.
    s: int
        Spin index.
    ibz2bz: IBZ2BZMaps
    bcomm: band communicator or None
    """
    nbands = wfs.bd.nbands
    ibz_proj = wfs.kpt_qs[wfs.kd.bz2ibz_k[K]][s].projections

    # Copy the IBZ projections into a fresh object before symmetry mapping.
    proj = ibz_proj.new(nbands=nbands, bcomm=bcomm)
    proj.array[:] = ibz_proj.array[:nbands]

    return ibz2bz[K].map_projections(proj)
def read_umat(seed, kd, dis=False):
    """Read a Wannier90 transformation-matrix file.

    Parameters
    ----------
    seed: str
        Seed name.  ``'_u_dis.mat'`` (if *dis*) or ``'_u.mat'`` is
        appended unless *seed* already contains ``'.mat'``, in which case
        it is used as the file name.
    kd: k-point descriptor
        Must provide ``nbzkpts``, ``bzk_kc`` and ``where_is_q``.
    dis: bool
        Read the disentanglement matrix file instead.

    Returns
    -------
    uwan: complex ndarray, shape (nw1, nw2, nk)
        Matrix elements; each k-point block of the file is stored at the
        BZ index that ``kd.where_is_q`` finds for its k-vector.
    nk, nw1, nw2: int
        Dimensions from the file header.
    """
    if ".mat" not in seed:
        seed += "_u_dis.mat" if dis else "_u.mat"
    with open(seed, "r") as f:
        f.readline()  # first line is a comment
        nk, nw1, nw2 = [int(i) for i in f.readline().split()]
        assert nk == kd.nbzkpts
        uwan = np.empty([nw1, nw2, nk], dtype=complex)
        iklist = []  # BZ indices found so far
        for _ in range(nk):
            f.readline()  # empty line
            K_c = [float(x) for x in f.readline().split()]
            ik = kd.where_is_q(K_c, kd.bzk_kc)
            assert np.allclose(np.array(K_c), kd.bzk_kc[ik])
            iklist.append(ik)
            for ib1 in range(nw1):
                for ib2 in range(nw2):
                    re_part, im_part = [float(x) for x in
                                        f.readline().split()]
                    uwan[ib1, ib2, ik] = complex(re_part, im_part)
    assert set(iklist) == set(range(nk))  # check that all k:s were found
    return uwan, nk, nw1, nw2
def read_uwan(seed, kd, dis=False):
    """Read the Wannier transformation matrix for *seed*.

    Parameters
    ----------
    seed: str
        Seed of the Wannier90 calculation (without a ``.mat`` suffix).
    kd: kpt descriptor
    dis: bool
        Set to True if nband > nwan.  The ``_u.mat`` matrix is then
        multiplied, k-point by k-point, with the ``_u_dis.mat`` matrix.

    Returns
    -------
    uwan, nk, nw1, nw2
        The dimensions are those of the last file read.
    """
    assert '.mat' not in seed

    # Wannier transformation matrix.
    umat, nk, nw1, nw2 = read_umat(seed, kd, dis=False)
    if not dis:
        return umat, nk, nw1, nw2

    # Transformation to the optimal subspace, combined per k-point.
    umat_dis, nk, nw1, nw2 = read_umat(seed, kd, dis=True)
    uwan = np.zeros_like(umat_dis)
    for k in range(nk):
        uwan[..., k] = np.matmul(umat[..., k], umat_dis[..., k])
    return uwan, nk, nw1, nw2