Coverage for gpaw/wannier/w90.py: 96%
127 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1import subprocess
2from pathlib import Path
3from typing import Union, IO, Dict, Any, cast
5from ase import Atoms
6import numpy as np
8from .overlaps import WannierOverlaps
9from .functions import WannierFunctions
10from gpaw.typing import Array3D
class Wannier90Error(Exception):
    """Raised when a run of the Wannier90 executable reports a failure."""
class Wannier90:
    """Driver for an external Wannier90 executable.

    Input files (``<prefix>.win``, ``<prefix>.mmn`` and ``<prefix>.amn``) are
    written into *folder*, the executable is run with *folder* as its working
    directory, and the resulting ``<prefix>.wout`` file can be parsed with
    :meth:`read_result`.
    """

    def __init__(self,
                 prefix: str = 'wannier',
                 folder: Union[str, Path] = 'W90',
                 executable='wannier90.x'):
        """Set up the working folder.

        Parameters
        ----------
        prefix:
            Stem used for every Wannier90 file (``<prefix>.win`` etc.).
        folder:
            Directory where files are written and Wannier90 is run.  It is
            created, including any missing parent directories, if needed.
        executable:
            Name or path of the Wannier90 program.
        """
        self.prefix = prefix
        self.folder = Path(folder)
        self.executable = executable
        # parents=True so that nested folders such as 'runs/W90' also work.
        self.folder.mkdir(parents=True, exist_ok=True)

    def run_wannier90(self, postprocess=False, world=None):
        """Run the Wannier90 executable inside ``self.folder``.

        Standard output is captured (not shown) and inspected for errors.

        Parameters
        ----------
        postprocess:
            If true, pass ``-pp`` on the command line before the prefix.
        world:
            Currently unused; accepted for interface compatibility.

        Raises
        ------
        Wannier90Error
            If standard output contains ``Error:`` or the program exits
            with a non-zero status.
        """
        args = [self.executable]
        if postprocess:
            args.append('-pp')
        args.append(self.prefix)
        result = subprocess.run(args,
                                cwd=self.folder,
                                stdout=subprocess.PIPE)
        # errors='replace' so that odd bytes in the output can never mask
        # the real failure with a UnicodeDecodeError.
        output = result.stdout.decode(errors='replace')
        if b'Error:' in result.stdout:
            raise Wannier90Error(output)
        if result.returncode != 0:
            # A failed run does not necessarily print 'Error:' on stdout.
            raise Wannier90Error(
                f'{self.executable} exited with status '
                f'{result.returncode}:\n{output}')

    def write_input_files(self,
                          overlaps: WannierOverlaps,
                          **kwargs) -> None:
        """Write the ``.win`` and ``.mmn`` files, and ``.amn`` if available.

        Extra keyword arguments are forwarded to :meth:`write_win`.  The
        ``.amn`` file is only written when ``overlaps.projections`` is not
        None.
        """
        self.write_win(overlaps, **kwargs)
        self.write_mmn(overlaps)
        if overlaps.projections is not None:
            self.write_amn(overlaps.projections)

    def write_win(self,
                  overlaps: WannierOverlaps,
                  **kwargs) -> None:
        """Write the ``<prefix>.win`` input file.

        Extra keyword arguments are written as additional parameters.  The
        entries derived from *overlaps* override any keyword with the same
        name.  Those entries are: band/Wannier counts, Fermi level, cell,
        fractional atomic positions, Monkhorst-Pack grid and k-points.  When
        projection indices are present, guiding centres and projections are
        added too.

        Values are formatted by type:

        - tuples become ``key = v1 v2 ...``;
        - lists and arrays become a ``begin key`` / ``end key`` block with
          one row per line;
        - anything else becomes ``key = value``.
        """
        kwargs['num_bands'] = overlaps.nbands
        kwargs['num_wann'] = overlaps.nwannier
        kwargs['fermi_energy'] = overlaps.fermi_level
        kwargs['unit_cell_cart'] = overlaps.atoms.cell.tolist()
        kwargs['atoms_frac'] = [[symbol] + list(spos_c)
                                for symbol, spos_c
                                in zip(overlaps.atoms.symbols,
                                       overlaps.atoms.get_scaled_positions())]
        kwargs['mp_grid'] = tuple(overlaps.monkhorst_pack_size)
        kwargs['kpoints'] = overlaps.kpoints
        if overlaps.proj_indices_a:
            kwargs['guiding_centres'] = True
            # One 's' projection centred on each atom's position, repeated
            # once per projection index belonging to that atom.
            centers = []
            for (x, y, z), indices in zip(overlaps.atoms.positions,
                                          overlaps.proj_indices_a):
                centers += [[f'c={x},{y},{z}: s']] * len(indices)
            kwargs['projections'] = centers

        with (self.folder / f'{self.prefix}.win').open('w') as fd:
            for key, val in kwargs.items():
                if isinstance(val, tuple):
                    print(f'{key} =', *val, file=fd)
                elif isinstance(val, (list, np.ndarray)):
                    print(f'begin {key}', file=fd)
                    for line in val:
                        print(' ', *line, file=fd)
                    print(f'end {key}', file=fd)
                else:
                    print(f'{key} = {val}', file=fd)

    def write_mmn(self,
                  overlaps: WannierOverlaps) -> None:
        """Write the ``<prefix>.mmn`` overlap file.

        The header line holds the number of bands, the number of k-points
        and the number of neighbour directions.  Each direction from
        ``overlaps.directions`` is used together with its negative.

        For every k-point and direction, one line ``k1 k2 Gx Gy Gz`` is
        written, with 1-based k-point indices.  It is followed by the
        overlap matrix, column by column, one ``real imag`` pair per line.
        """
        size = overlaps.monkhorst_pack_size
        nbzkpts = cast(int, np.prod(size))
        nbands = overlaps.nbands

        directions = list(overlaps.directions)
        directions += [(-a, -b, -c) for (a, b, c) in directions]
        ndirections = len(directions)

        with (self.folder / f'{self.prefix}.mmn').open('w') as fd:
            print('Input generated from GPAW', file=fd)
            print(f'{nbands} {nbzkpts} {ndirections}', file=fd)

            for bz_index1 in range(nbzkpts):
                i1_c = np.unravel_index(bz_index1, size)
                for direction in directions:
                    i2_c = np.array(i1_c) + direction
                    # The neighbour may lie outside the grid; wrap it back
                    # in and record the integer shift d_c separately.
                    bz_index2 = np.ravel_multi_index(i2_c,
                                                     size,
                                                     'wrap')  # type: ignore
                    d_c = (i2_c - i2_c % size) // size
                    print(bz_index1 + 1, bz_index2 + 1, *d_c, file=fd)
                    M_nn = overlaps.overlap(bz_index1, direction)
                    # Transposing makes the first index vary fastest.
                    for M_n in M_nn.T:
                        for M in M_n:
                            print(f'{M.real} {M.imag}', file=fd)

    def write_amn(self,
                  proj_kmn: Array3D) -> None:
        """Write the ``<prefix>.amn`` projection file.

        *proj_kmn* is indexed as ``[k-point, projection, band]``.  Each
        output line is ``band projection k-point Re Im``, with 1-based
        indices.  The imaginary part is written negated, i.e. the complex
        conjugate of ``proj_kmn`` is stored.  NOTE(review): presumably
        because Wannier90 expects the conjugate of GPAW's convention;
        confirm.
        """
        nbzkpts, nproj, nbands = proj_kmn.shape

        with (self.folder / f'{self.prefix}.amn').open('w') as fd:
            print('Input generated from GPAW', file=fd)
            print(f'{nbands} {nbzkpts} {nproj}', file=fd)

            for bz_index, proj_mn in enumerate(proj_kmn):
                for m, proj_n in enumerate(proj_mn):
                    for n, P in enumerate(proj_n):
                        print(n + 1, m + 1, bz_index + 1, P.real, -P.imag,
                              file=fd)

    def read_result(self):
        """Parse ``<prefix>.wout`` and return a :class:`Wannier90Functions`.

        Only the atoms and the Wannier centers are passed on.  The spreads
        that :func:`read_wout_all` also parses are discarded here.
        """
        with (self.folder / f'{self.prefix}.wout').open() as fd:
            w = read_wout_all(fd)
        return Wannier90Functions(w['atoms'], w['centers'])
class Wannier90Functions(WannierFunctions):
    """Wannier functions built from centers read out of Wannier90 output."""

    def __init__(self,
                 atoms: Atoms,
                 centers):
        # Only the atoms and the centers are known here.  The remaining
        # positional arguments of WannierFunctions are filled with 0.0 and
        # an empty list.
        super().__init__(atoms, centers, 0.0, [])
def read_wout_all(fileobj: IO[str]) -> Dict[str, Any]:
    """Read atoms, Wannier-function centers and spreads from ``.wout`` text.

    Parameters
    ----------
    fileobj:
        Open text stream with the contents of a Wannier90 ``.wout`` file.

    Returns
    -------
    dict
        ``'atoms'``
            :class:`ase.Atoms` with the cell and the Cartesian positions
            listed in the file, periodic in all directions.
        ``'centers'``
            ``(nwannier, 3)`` array taken from the last ``Final State``
            section.
        ``'spreads'``
            ``(nwannier,)`` array of the corresponding spreads.

        Both arrays are empty, with shapes ``(0, 3)`` and ``(0,)``, when
        there is no ``Final State`` section or it lists no ``WF`` rows.

    Raises
    ------
    ValueError
        If the lattice-vector or the atomic-coordinate section is missing.
    """
    lines = fileobj.readlines()
    nlines = len(lines)

    for n, line in enumerate(lines):
        if line.strip().lower().startswith('lattice vectors (ang)'):
            break
    else:
        raise ValueError('Could not find lattice vectors')

    # Each of the three lines after the header ends with the x, y, z
    # components of one lattice vector.
    cell = [[float(x) for x in line.split()[-3:]]
            for line in lines[n + 1:n + 4]]

    for n, line in enumerate(lines):
        if 'cartesian coordinate (ang)' in line.lower():
            break
    else:
        raise ValueError('Could not find coordinates')

    positions = []
    symbols = []
    # Skip the header line and the separator below it.  Each atom row is
    # split into tokens: the symbol is the 2nd token and the Cartesian
    # x, y, z are the three tokens before the last.  The table ends at a
    # single-token separator line.  Blank lines and end of file are also
    # treated as the end of the table instead of raising IndexError.
    n += 2
    while n < nlines:
        words = lines[n].split()
        if len(words) <= 1:
            break
        positions.append([float(x) for x in words[-4:-1]])
        symbols.append(words[1])
        n += 1

    atoms = Atoms(symbols, positions, cell=cell, pbc=True)

    # Search backwards so that the last 'Final State' section is used.
    for n in range(nlines - 1, 0, -1):
        if lines[n].strip().lower().startswith('final state'):
            break
    else:
        return {'atoms': atoms,
                'centers': np.zeros((0, 3)),
                'spreads': np.zeros((0,))}

    centers = []
    spreads = []
    # Consecutive rows starting with 'WF' carry '( x, y, z )' followed by
    # the spread as the last token.  Stop at the first other line or at
    # end of file.
    for line in lines[n + 1:]:
        line = line.strip()
        if not line.startswith('WF'):
            break
        inside = line.split('(')[1].split(')')[0]
        centers.append([float(x) for x in inside.split(',')])
        spreads.append(float(line.split()[-1]))

    # reshape keeps the (0, 3) shape when no WF rows were found, matching
    # the branch above.
    return {'atoms': atoms,
            'centers': np.array(centers).reshape(-1, 3),
            'spreads': np.array(spreads)}