Coverage for gpaw/wannier/w90.py: 96%

127 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1import subprocess 

2from pathlib import Path 

3from typing import Union, IO, Dict, Any, cast 

4 

5from ase import Atoms 

6import numpy as np 

7 

8from .overlaps import WannierOverlaps 

9from .functions import WannierFunctions 

10from gpaw.typing import Array3D 

11 

12 

class Wannier90Error(Exception):
    """Raised when a run of the external Wannier90 program fails."""

15 

16 

class Wannier90:
    """Driver for the external Wannier90 program.

    Input files (``.win``, ``.mmn``, ``.amn``) are written into *folder*,
    the executable is run there, and the ``.wout`` output is read back.
    All files are named ``<prefix>.<ext>``.

    Parameters
    ----------
    prefix:
        Seed name passed to the executable.
    folder:
        Working directory.  Created if missing (its parent must exist).
    executable:
        Name or path of the Wannier90 executable.
    """

    def __init__(self,
                 prefix: str = 'wannier',
                 folder: Union[str, Path] = 'W90',
                 executable='wannier90.x'):
        self.prefix = prefix
        self.folder = Path(folder)
        self.executable = executable
        # exist_ok: reuse a folder left over from an earlier run.
        self.folder.mkdir(exist_ok=True)

    def run_wannier90(self, postprocess=False, world=None):
        """Run the executable in ``self.folder`` and check for failure.

        Parameters
        ----------
        postprocess:
            If true, the ``-pp`` flag is passed before the prefix.
        world:
            Currently unused.
            NOTE(review): presumably intended to restrict the run to a
            single MPI rank -- confirm with callers.

        Raises
        ------
        Wannier90Error
            If the captured standard output contains ``'Error:'`` or the
            process exits with a non-zero status.
        """
        args = [self.executable, self.prefix]
        if postprocess:
            args.insert(1, '-pp')
        result = subprocess.run(args,
                                cwd=self.folder,
                                stdout=subprocess.PIPE)
        # Decode leniently so an undecodable byte cannot hide the error.
        output = result.stdout.decode(errors='replace')
        if b'Error:' in result.stdout:
            raise Wannier90Error(output)
        # A crash or abort may not print 'Error:'; never ignore it.
        if result.returncode != 0:
            raise Wannier90Error(
                f'{self.executable} exited with status '
                f'{result.returncode}:\n{output}')

    def write_input_files(self,
                          overlaps: WannierOverlaps,
                          **kwargs) -> None:
        """Write the ``.win`` and ``.mmn`` files, and ``.amn`` if possible.

        The ``.amn`` file is only written when ``overlaps.projections``
        is not None.  *kwargs* are forwarded to :meth:`write_win`.
        """
        self.write_win(overlaps, **kwargs)
        self.write_mmn(overlaps)
        if overlaps.projections is not None:
            self.write_amn(overlaps.projections)

    def write_win(self,
                  overlaps: WannierOverlaps,

                  **kwargs) -> None:
        """Write the ``<prefix>.win`` parameter file.

        *kwargs* are written as extra parameters.  The keys computed from
        *overlaps* (``num_bands``, ``num_wann``, ``fermi_energy``,
        ``unit_cell_cart``, ``atoms_frac``, ``mp_grid``, ``kpoints`` and,
        when projection indices exist, ``guiding_centres`` and
        ``projections``) override user-supplied values of the same name.

        Formatting rule per value: a tuple becomes ``key = a b c``; a
        list or ndarray becomes a ``begin key`` / ``end key`` block with
        one row per line; anything else becomes ``key = value``.
        """
        kwargs['num_bands'] = overlaps.nbands
        kwargs['num_wann'] = overlaps.nwannier
        kwargs['fermi_energy'] = overlaps.fermi_level
        kwargs['unit_cell_cart'] = overlaps.atoms.cell.tolist()
        kwargs['atoms_frac'] = [[symbol] + list(spos_c)
                                for symbol, spos_c
                                in zip(overlaps.atoms.symbols,
                                       overlaps.atoms.get_scaled_positions())]
        kwargs['mp_grid'] = tuple(overlaps.monkhorst_pack_size)
        kwargs['kpoints'] = overlaps.kpoints
        if overlaps.proj_indices_a:
            kwargs['guiding_centres'] = True
            centers = []
            # One 's'-type projection line per projection index, centred
            # at the Cartesian position of the owning atom.
            for (x, y, z), indices in zip(overlaps.atoms.positions,
                                          overlaps.proj_indices_a):
                centers += [[f'c={x},{y},{z}: s']] * len(indices)
            kwargs['projections'] = centers

        with (self.folder / f'{self.prefix}.win').open('w') as fd:
            for key, val in kwargs.items():
                if isinstance(val, tuple):
                    print(f'{key} =', *val, file=fd)
                elif isinstance(val, (list, np.ndarray)):
                    print(f'begin {key}', file=fd)
                    for line in val:
                        print(' ', *line, file=fd)
                    print(f'end {key}', file=fd)
                else:
                    print(f'{key} = {val}', file=fd)

    def write_mmn(self,
                  overlaps: WannierOverlaps) -> None:
        """Write ``<prefix>.mmn`` containing the overlap matrices.

        For every k-point of the Monkhorst-Pack grid and for every
        direction in ``overlaps.directions`` and its negative, a header
        line ``k1 k2 d1 d2 d3`` is written.  ``k1`` and ``k2`` are 1-based
        k-point indices, and ``d`` counts how many times the neighbour
        index wrapped around the grid along each axis.  The matrix
        returned by ``overlaps.overlap`` follows as ``real imag`` lines,
        with its first index running fastest.
        """
        size = overlaps.monkhorst_pack_size
        nbzkpts = cast(int, np.prod(size))
        nbands = overlaps.nbands

        directions = list(overlaps.directions)
        # Neighbours are listed in both the +b and the -b direction.
        directions += [(-a, -b, -c) for (a, b, c) in directions]
        ndirections = len(directions)

        with (self.folder / f'{self.prefix}.mmn').open('w') as fd:
            print('Input generated from GPAW', file=fd)
            print(f'{nbands} {nbzkpts} {ndirections}', file=fd)

            for bz_index1 in range(nbzkpts):
                i1_c = np.unravel_index(bz_index1, size)
                for direction in directions:
                    i2_c = np.array(i1_c) + direction
                    # 'wrap' folds the neighbour back onto the grid ...
                    bz_index2 = np.ravel_multi_index(i2_c,
                                                     size,
                                                     'wrap')  # type: ignore
                    # ... and d_c records how far it was folded.
                    d_c = i2_c // size
                    print(bz_index1 + 1, bz_index2 + 1, *d_c, file=fd)
                    M_nn = overlaps.overlap(bz_index1, direction)
                    # Transpose + ravel gives the first index fastest;
                    # one write per matrix instead of one per element.
                    fd.write(''.join(f'{M.real} {M.imag}\n'
                                     for M in M_nn.T.ravel()))

    def write_amn(self,
                  proj_kmn: Array3D) -> None:
        """Write ``<prefix>.amn`` from projections.

        *proj_kmn* is indexed ``[k-point, projection, band]``.  Each line
        is ``band projection k-point real imag`` (1-based indices) with
        the band index running fastest.  The complex conjugate of each
        entry is written (note the ``-P.imag``).
        """
        nbzkpts, nproj, nbands = proj_kmn.shape

        with (self.folder / f'{self.prefix}.amn').open('w') as fd:
            print('Input generated from GPAW', file=fd)
            print(f'{nbands} {nbzkpts} {nproj}', file=fd)

            for bz_index, proj_mn in enumerate(proj_kmn):
                for m, proj_n in enumerate(proj_mn):
                    for n, P in enumerate(proj_n):
                        print(n + 1, m + 1, bz_index + 1, P.real, -P.imag,
                              file=fd)

    def read_result(self):
        """Parse ``<prefix>.wout`` and return the Wannier90 result.

        Returns a :class:`Wannier90Functions` built from the atoms and
        the final Wannier centres found in the output file.
        """
        with (self.folder / f'{self.prefix}.wout').open() as fd:
            w = read_wout_all(fd)
        return Wannier90Functions(w['atoms'], w['centers'])

125 

126 

class Wannier90Functions(WannierFunctions):
    """WannierFunctions assembled from Wannier90 output.

    Only the atoms and the Wannier centres are supplied.  The remaining
    two arguments of ``WannierFunctions.__init__`` are filled with
    ``0.0`` and a fresh empty list.  Their meaning is defined by
    WannierFunctions, which is not shown here.
    """

    def __init__(self, atoms: Atoms, centers):
        super().__init__(atoms, centers, 0.0, [])

132 

133 

def read_wout_all(fileobj: IO[str]) -> Dict[str, Any]:
    """Read atoms, Wannier function centres and spreads from .wout text.

    Parameters
    ----------
    fileobj:
        Open text file containing Wannier90 output.

    Returns
    -------
    dict
        ``'atoms'``
            :class:`ase.Atoms` with ``pbc=True``.  The cell comes from
            the three lines after the "Lattice Vectors (Ang)" header.
            Symbols and positions come from the table after the
            "Cartesian Coordinate (Ang)" header.
        ``'centers'``
            Array of shape ``(nwf, 3)``, one row per ``WF`` line of the
            last "Final State" section.
        ``'spreads'``
            Array of shape ``(nwf,)`` with the matching spreads.

        Without a "Final State" section, ``centers`` has shape ``(0, 3)``
        and ``spreads`` has shape ``(0,)``.

    Raises
    ------
    ValueError
        If the lattice vectors or the coordinate table cannot be found.
    """
    lines = fileobj.readlines()

    for n, line in enumerate(lines):
        if line.strip().lower().startswith('lattice vectors (ang)'):
            break
    else:
        raise ValueError('Could not find lattice vectors')

    # The three lines after the header end with the Cartesian components.
    cell = [[float(x) for x in line.split()[-3:]]
            for line in lines[n + 1:n + 4]]

    for n, line in enumerate(lines):
        if 'cartesian coordinate (ang)' in line.lower():
            break
    else:
        raise ValueError('Could not find coordinates')

    positions = []
    symbols = []
    # Rows start two lines below the header.  In each row the symbol is
    # the second word, and x, y, z are the three words before the last.
    # The table ends at a line with at most one word (the closing
    # border) or at end of file.
    for line in lines[n + 2:]:
        words = line.split()
        if len(words) <= 1:
            break
        positions.append([float(x) for x in words[-4:-1]])
        symbols.append(words[1])

    atoms = Atoms(symbols, positions, cell=cell, pbc=True)

    # Search backwards so the last "Final State" section is used.
    for n in range(len(lines) - 1, -1, -1):
        if lines[n].strip().lower().startswith('final state'):
            break
    else:
        return {'atoms': atoms,
                'centers': np.zeros((0, 3)),
                'spreads': np.zeros((0,))}

    centers = []
    spreads = []
    # Consecutive lines starting with 'WF' follow the header.  The centre
    # is the comma-separated triple in parentheses, and the spread is the
    # last word.
    for line in lines[n + 1:]:
        line = line.strip()
        if not line.startswith('WF'):
            break
        inside = line.split('(')[1].split(')')[0]
        centers.append([float(x) for x in inside.split(',')])
        spreads.append(float(line.split()[-1]))

    return {'atoms': atoms,
            # reshape keeps shape (0, 3) when no WF lines were found.
            'centers': np.array(centers, dtype=float).reshape((-1, 3)),
            'spreads': np.array(spreads)}