Coverage for gpaw/lcaotddft/tcm.py: 10%

146 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1import numpy as np 

2 

3 

4def generate_gridspec(**kwargs): 

5 from matplotlib.gridspec import GridSpec 

6 width = 0.84 

7 bottom = 0.12 

8 left = 0.12 

9 return GridSpec(2, 2, width_ratios=[3, 1], height_ratios=[1, 3], 

10 bottom=bottom, top=bottom + width, 

11 left=left, right=left + width, 

12 **kwargs) 

13 

14 

15def plot_DOS(ax, energy_e, dos_e, base_e, dos_min, dos_max, 

16 flip=False, fill=None, line=None): 

17 ax.xaxis.set_ticklabels([]) 

18 ax.yaxis.set_ticklabels([]) 

19 ax.spines['right'].set_visible(False) 

20 ax.spines['top'].set_visible(False) 

21 ax.yaxis.set_ticks_position('left') 

22 ax.xaxis.set_ticks_position('bottom') 

23 if flip: 

24 set_label = ax.set_xlabel 

25 fill_between = ax.fill_betweenx 

26 set_energy_lim = ax.set_ylim 

27 set_dos_lim = ax.set_xlim 

28 

29 def plot(x, y, *args, **kwargs): 

30 return ax.plot(y, x, *args, **kwargs) 

31 else: 

32 set_label = ax.set_ylabel 

33 fill_between = ax.fill_between 

34 set_energy_lim = ax.set_xlim 

35 set_dos_lim = ax.set_ylim 

36 

37 def plot(x, y, *args, **kwargs): 

38 return ax.plot(x, y, *args, **kwargs) 

39 if fill: 

40 fill_between(energy_e, base_e, dos_e + base_e, **fill) 

41 if line: 

42 plot(energy_e, dos_e, **line) 

43 set_label('DOS', labelpad=0) 

44 set_energy_lim(np.take(energy_e, (0, -1))) 

45 set_dos_lim(dos_min, dos_max) 

46 

47 

48class TCM: 

49 

50 def __init__(self, energy_o, energy_u, fermilevel): 

51 self.energy_o = energy_o 

52 self.energy_u = energy_u 

53 self.fermilevel = fermilevel 

54 

55 self.base_o = np.zeros_like(energy_o) 

56 self.base_u = np.zeros_like(energy_u) 

57 

58 def __getattr__(self, attr): 

59 import matplotlib.pyplot as plt 

60 # Generate axis only when needed 

61 if attr in ['ax_occ_dos', 'ax_unocc_dos', 'ax_tcm']: 

62 gs = generate_gridspec(hspace=0.05, wspace=0.05) 

63 self.ax_occ_dos = plt.subplot(gs[0]) 

64 self.ax_unocc_dos = plt.subplot(gs[3]) 

65 self.ax_tcm = plt.subplot(gs[2]) 

66 return getattr(self, attr) 

67 if attr in ['ax_spec']: 

68 gs = generate_gridspec(hspace=0.8, wspace=0.8) 

69 self.ax_spec = plt.subplot(gs[1]) 

70 return getattr(self, attr) 

71 if attr in ['ax_cbar']: 

72 self.ax_cbar = plt.axes((0.15, 0.6, 0.02, 0.1)) 

73 return getattr(self, attr) 

74 raise AttributeError('%s object has no attribute %s' % 

75 (repr(self.__class__.__name__), repr(attr))) 

76 

77 def plot_TCM(self, tcm_ou, vmax='80%', vmin='symmetrize', cmap='seismic', 

78 log=False, colorbar=True, lw=None): 

79 import matplotlib as mpl 

80 import matplotlib.pyplot as plt 

81 from matplotlib.colors import Normalize, LogNorm 

82 if lw is None: 

83 lw = mpl.rcParams['lines.linewidth'] 

84 energy_o = self.energy_o 

85 energy_u = self.energy_u 

86 fermilevel = self.fermilevel 

87 

88 tcmmax = np.max(np.absolute(tcm_ou)) 

89 print('tcmmax', tcmmax) 

90 

91 # Plot TCM 

92 ax = self.ax_tcm 

93 plt.sca(ax) 

94 plt.cla() 

95 if isinstance(vmax, str): 

96 assert vmax[-1] == '%' 

97 tcmmax = np.max(np.absolute(tcm_ou)) 

98 vmax = tcmmax * float(vmax[:-1]) / 100.0 

99 if vmin == 'symmetrize': 

100 vmin = -vmax 

101 if tcm_ou.dtype == complex: 

102 linecolor = 'w' 

103 from matplotlib.colors import hsv_to_rgb 

104 

105 def transform_to_hsv(z, rmin, rmax, hue_start=90): 

106 amp = np.absolute(z) # **2 

107 amp = np.where(amp < rmin, rmin, amp) 

108 amp = np.where(amp > rmax, rmax, amp) 

109 ph = np.angle(z, deg=1) + hue_start 

110 h = (ph % 360) / 360 

111 s = 1.85 * np.ones_like(h) 

112 v = (amp - rmin) / (rmax - rmin) 

113 return hsv_to_rgb(np.dstack((h, s, v))) 

114 

115 img = transform_to_hsv(tcm_ou.T, 0, vmax) 

116 plt.imshow(img, origin='lower', 

117 extent=[energy_o[0], energy_o[-1], 

118 energy_u[0], energy_u[-1]], 

119 interpolation='bilinear', 

120 ) 

121 else: 

122 linecolor = 'k' 

123 if cmap == 'magma': 

124 linecolor = 'w' 

125 if log: 

126 norm = LogNorm(vmin=vmin, vmax=vmax) 

127 else: 

128 norm = Normalize(vmin=vmin, vmax=vmax) 

129 plt.pcolormesh(energy_o, energy_u, tcm_ou.T, 

130 cmap=cmap, rasterized=True, norm=norm, 

131 shading='auto') 

132 if colorbar: 

133 ax = self.ax_cbar 

134 ax.clear() 

135 cb = plt.colorbar(cax=ax) 

136 cb.outline.set_edgecolor(linecolor) 

137 ax.tick_params(axis='both', colors=linecolor) 

138 # ax.yaxis.label.set_color(linecolor) 

139 # ax.xaxis.label.set_color(linecolor) 

140 ax = self.ax_tcm 

141 plt.sca(ax) 

142 plt.axhline(fermilevel, c=linecolor, lw=lw) 

143 plt.axvline(fermilevel, c=linecolor, lw=lw) 

144 

145 ax.tick_params(axis='both', which='major', pad=2) 

146 plt.xlabel(r'Occ. energy $\varepsilon_{o}$ (eV)', labelpad=0) 

147 plt.ylabel(r'Unocc. energy $\varepsilon_{u}$ (eV)', labelpad=0) 

148 plt.xlim(np.take(energy_o, (0, -1))) 

149 plt.ylim(np.take(energy_u, (0, -1))) 

150 

151 def plot_DOS(self, dos_o, dos_u, stack=False, 

152 fill={'color': '0.8'}, line={'color': 'k'}): 

153 # Plot DOSes 

154 if stack: 

155 base_o = self.base_o 

156 base_u = self.base_u 

157 else: 

158 base_o = np.zeros_like(self.energy_o) 

159 base_u = np.zeros_like(self.energy_u) 

160 dos_min = 0.0 

161 dos_max = 1.01 * max(np.max(dos_o), np.max(dos_u)) 

162 plot_DOS(self.ax_occ_dos, self.energy_o, dos_o, base_o, 

163 dos_min, dos_max, flip=False, fill=fill, line=line) 

164 plot_DOS(self.ax_unocc_dos, self.energy_u, dos_u, base_u, 

165 dos_min, dos_max, flip=True, fill=fill, line=line) 

166 if stack: 

167 self.base_o += dos_o 

168 self.base_u += dos_u 

169 

170 def plot_spectrum(self): 

171 raise NotImplementedError() 

172 

173 def plot_TCM_diagonal(self, energy, **kwargs): 

174 x_o = np.take(self.energy_o, (0, -1)) 

175 self.ax_tcm.plot(x_o, x_o + energy, **kwargs) 

176 

177 def set_title(self, *args, **kwargs): 

178 self.ax_occ_dos.set_title(*args, **kwargs) 

179 

180 

181class TCMPlotter(TCM): 

182 

183 def __init__(self, ksd, energy_o, energy_u, sigma, 

184 zero_fermilevel=True): 

185 eig_n, fermilevel = ksd.get_eig_n(zero_fermilevel) 

186 TCM.__init__(self, energy_o, energy_u, fermilevel) 

187 self.ksd = ksd 

188 self.sigma = sigma 

189 self.eig_n = eig_n 

190 

191 def plot_TCM(self, weight_p, **kwargs): 

192 # Calculate TCM 

193 tcm_ou = self.ksd.get_TCM(weight_p, self.eig_n, self.energy_o, 

194 self.energy_u, self.sigma) 

195 TCM.plot_TCM(self, tcm_ou, **kwargs) 

196 

197 def plot_DOS(self, weight_n=1.0, **kwargs): 

198 # Calculate DOS 

199 dos_o, dos_u = self.ksd.get_weighted_DOS(weight_n, self.eig_n, 

200 self.energy_o, 

201 self.energy_u, 

202 self.sigma) 

203 TCM.plot_DOS(self, dos_o, dos_u, **kwargs)