Coverage for gpaw/lcaotddft/tcm.py: 10%
146 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1import numpy as np
def generate_gridspec(**kwargs):
    """Return the 2x2 GridSpec layout shared by the TCM figures.

    Columns are split 3:1 and rows 1:3.  The gridded region spans from
    0.12 to 0.96 of the figure both horizontally and vertically.  Any
    extra keyword arguments (e.g. ``hspace``, ``wspace``) are forwarded
    to ``GridSpec``; repeating one of the fixed layout keywords raises
    ``TypeError``.
    """
    from matplotlib.gridspec import GridSpec
    margin = 0.12  # left and bottom offset of the gridded region
    extent = 0.84  # width and height of the gridded region
    layout = dict(width_ratios=[3, 1], height_ratios=[1, 3],
                  bottom=margin, top=margin + extent,
                  left=margin, right=margin + extent)
    return GridSpec(2, 2, **layout, **kwargs)
def plot_DOS(ax, energy_e, dos_e, base_e, dos_min, dos_max,
             flip=False, fill=None, line=None):
    """Draw one DOS side panel of a TCM figure on ``ax``.

    Tick labels and the top/right spines are hidden.  With ``flip=False``
    energy runs along x and DOS along y; with ``flip=True`` the roles are
    swapped (energy along y, DOS along x).

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    energy_e : array_like
        Energy grid; its first and last entries set the energy-axis limits.
    dos_e : ndarray
        DOS values on ``energy_e``.
    base_e : ndarray
        Baseline; the filled band spans ``base_e`` .. ``base_e + dos_e``.
    dos_min, dos_max : float
        Limits of the DOS axis.
    flip : bool
        Swap the axes as described above.
    fill : dict or None
        If truthy, keyword arguments for ``fill_between``/``fill_betweenx``.
    line : dict or None
        If truthy, keyword arguments for ``ax.plot`` of the DOS curve.
        NOTE(review): the curve is drawn at ``dos_e`` itself, not offset by
        ``base_e`` like the filled band.
    """
    # Bare panel: no tick labels, only the left and bottom frame lines.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_ticklabels([])
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    if flip:
        label_fn, band_fn = ax.set_xlabel, ax.fill_betweenx
        energy_lim_fn, dos_lim_fn = ax.set_ylim, ax.set_xlim
    else:
        label_fn, band_fn = ax.set_ylabel, ax.fill_between
        energy_lim_fn, dos_lim_fn = ax.set_xlim, ax.set_ylim

    if fill:
        band_fn(energy_e, base_e, dos_e + base_e, **fill)
    if line:
        # Data coordinates are (DOS, energy) when flipped.
        xy = (dos_e, energy_e) if flip else (energy_e, dos_e)
        ax.plot(*xy, **line)
    label_fn('DOS', labelpad=0)
    energy_lim_fn(np.take(energy_e, (0, -1)))
    dos_lim_fn(dos_min, dos_max)
class TCM:
    """Transition contribution map (TCM) figure with DOS side panels.

    The map is drawn with occupied-state energies on the x axis and
    unoccupied-state energies on the y axis.  Using the layout from
    ``generate_gridspec``, the axes are:

    * ``ax_tcm``       -- lower-left cell, the map itself;
    * ``ax_occ_dos``   -- upper-left cell, DOS over ``energy_o``;
    * ``ax_unocc_dos`` -- lower-right cell, DOS over ``energy_u``;
    * ``ax_spec``      -- upper-right cell (its own gridspec spacing);
    * ``ax_cbar``      -- small inset axes for the colorbar.

    The axes are created lazily, on the current pyplot figure, the first
    time they are accessed (see ``__getattr__``).

    Parameters
    ----------
    energy_o : array_like
        Energy grid of occupied states (presumably eV, matching the axis
        labels -- confirm with callers).
    energy_u : array_like
        Energy grid of unoccupied states.
    fermilevel : float
        Position of the horizontal and vertical guide lines in the map.
    """

    def __init__(self, energy_o, energy_u, fermilevel):
        self.energy_o = energy_o
        self.energy_u = energy_u
        self.fermilevel = fermilevel

        # Running baselines for stacked DOS plots.  Forced to floating point
        # so that ``+=`` with float DOS values works even for integer grids.
        self.base_o = np.zeros_like(energy_o, dtype=float)
        self.base_u = np.zeros_like(energy_u, dtype=float)

    def __getattr__(self, attr):
        # Only called when regular lookup fails, i.e. before the requested
        # axes exist.  Unknown names are rejected *before* importing pyplot,
        # so probes such as ``hasattr(obj, name)`` neither import matplotlib
        # nor create axes.
        if attr not in ('ax_occ_dos', 'ax_unocc_dos', 'ax_tcm',
                        'ax_spec', 'ax_cbar'):
            raise AttributeError('%s object has no attribute %s' %
                                 (repr(self.__class__.__name__), repr(attr)))
        import matplotlib.pyplot as plt
        if attr == 'ax_spec':
            gs = generate_gridspec(hspace=0.8, wspace=0.8)
            self.ax_spec = plt.subplot(gs[1])
        elif attr == 'ax_cbar':
            # Fixed inset rectangle (left, bottom, width, height) in figure
            # coordinates.
            self.ax_cbar = plt.axes((0.15, 0.6, 0.02, 0.1))
        else:
            # The map and both DOS panels share one gridspec, so they are
            # created together on first access to any of them.
            gs = generate_gridspec(hspace=0.05, wspace=0.05)
            self.ax_occ_dos = plt.subplot(gs[0])
            self.ax_unocc_dos = plt.subplot(gs[3])
            self.ax_tcm = plt.subplot(gs[2])
        return getattr(self, attr)

    def plot_TCM(self, tcm_ou, vmax='80%', vmin='symmetrize', cmap='seismic',
                 log=False, colorbar=True, lw=None):
        """Draw the transition contribution map on ``self.ax_tcm``.

        Parameters
        ----------
        tcm_ou : array_like
            2D map indexed ``[o, u]``; plotted with ``energy_o`` along x and
            ``energy_u`` along y.  Complex maps are rendered as an HSV image
            (phase -> hue, amplitude clipped to ``[0, vmax]`` -> value); real
            maps are drawn with ``pcolormesh`` using ``cmap``.
        vmax : float or str
            Upper colour limit.  A string ``'<p>%'`` means ``p`` percent of
            ``max(|tcm_ou|)``.
        vmin : float or 'symmetrize'
            Lower colour limit for real maps; ``'symmetrize'`` uses ``-vmax``.
            Not used for complex maps.
        cmap : str
            Colormap for real maps; ``'magma'`` switches guide lines to white.
        log : bool
            Use ``LogNorm`` instead of ``Normalize`` for real maps.
        colorbar : bool
            Draw a colorbar into ``self.ax_cbar``.
        lw : float, optional
            Width of the Fermi-level guide lines; defaults to
            ``rcParams['lines.linewidth']``.

        Raises
        ------
        ValueError
            If ``vmax`` is a string that does not end with ``'%'``.
        """
        tcm_ou = np.asarray(tcm_ou)
        tcmmax = np.max(np.absolute(tcm_ou))
        print('tcmmax', tcmmax)

        # Resolve colour limits before touching any axes, so invalid input
        # fails without clearing the existing plot.
        if isinstance(vmax, str):
            if not vmax.endswith('%'):
                raise ValueError(
                    f"vmax given as a string must end with '%', got {vmax!r}")
            vmax = tcmmax * float(vmax[:-1]) / 100.0
        if isinstance(vmin, str) and vmin == 'symmetrize':
            vmin = -vmax

        import matplotlib as mpl
        import matplotlib.pyplot as plt
        from matplotlib.colors import LogNorm, Normalize

        if lw is None:
            lw = mpl.rcParams['lines.linewidth']
        energy_o = self.energy_o
        energy_u = self.energy_u

        ax = self.ax_tcm
        plt.sca(ax)
        plt.cla()
        # iscomplexobj also catches complex64, which ``dtype == complex``
        # (complex128 only) would send to pcolormesh and silently truncate.
        if np.iscomplexobj(tcm_ou):
            linecolor = 'w'
            from matplotlib.colors import hsv_to_rgb

            def transform_to_hsv(z, rmin, rmax, hue_start=90):
                amp = np.absolute(z)
                amp = np.where(amp < rmin, rmin, amp)
                amp = np.where(amp > rmax, rmax, amp)
                ph = np.angle(z, deg=True) + hue_start
                h = (ph % 360) / 360
                # NOTE(review): saturation 1.85 lies outside the nominal
                # [0, 1] HSV range; kept unchanged from the original.
                s = 1.85 * np.ones_like(h)
                v = (amp - rmin) / (rmax - rmin)
                return hsv_to_rgb(np.dstack((h, s, v)))

            img = transform_to_hsv(tcm_ou.T, 0, vmax)
            mappable = plt.imshow(img, origin='lower',
                                  extent=[energy_o[0], energy_o[-1],
                                          energy_u[0], energy_u[-1]],
                                  interpolation='bilinear')
        else:
            linecolor = 'w' if cmap == 'magma' else 'k'
            norm_cls = LogNorm if log else Normalize
            norm = norm_cls(vmin=vmin, vmax=vmax)
            mappable = plt.pcolormesh(energy_o, energy_u, tcm_ou.T,
                                      cmap=cmap, rasterized=True, norm=norm,
                                      shading='auto')
        if colorbar:
            cax = self.ax_cbar
            cax.clear()
            # Pass the mappable explicitly: creating ax_cbar makes it the
            # current axes, so relying on the implicit current image is
            # fragile.
            cb = plt.colorbar(mappable, cax=cax)
            cb.outline.set_edgecolor(linecolor)
            cax.tick_params(axis='both', colors=linecolor)
        plt.sca(ax)
        plt.axhline(self.fermilevel, c=linecolor, lw=lw)
        plt.axvline(self.fermilevel, c=linecolor, lw=lw)

        ax.tick_params(axis='both', which='major', pad=2)
        plt.xlabel(r'Occ. energy $\varepsilon_{o}$ (eV)', labelpad=0)
        plt.ylabel(r'Unocc. energy $\varepsilon_{u}$ (eV)', labelpad=0)
        plt.xlim(np.take(energy_o, (0, -1)))
        plt.ylim(np.take(energy_u, (0, -1)))

    def plot_DOS(self, dos_o, dos_u, stack=False,
                 fill={'color': '0.8'}, line={'color': 'k'}):
        """Draw the occupied and unoccupied DOS side panels.

        Parameters
        ----------
        dos_o, dos_u : ndarray
            DOS sampled on ``energy_o`` and ``energy_u``.
        stack : bool
            If True, draw on top of the DOS accumulated by earlier stacked
            calls (``base_o``/``base_u``), then add this DOS to that running
            baseline.
        fill, line : dict or None
            Keyword arguments for the filled band and the curve, forwarded
            to the module-level ``plot_DOS``; a falsy value disables that
            element.  The default dicts are only unpacked there, never
            mutated, so sharing them between calls is harmless.
        """
        if stack:
            base_o = self.base_o
            base_u = self.base_u
        else:
            base_o = np.zeros_like(self.energy_o)
            base_u = np.zeros_like(self.energy_u)
        dos_min = 0.0
        # Scale to the top of the (possibly stacked) band, base + dos, so the
        # filled region is not clipped; identical to max(dos) when unstacked.
        dos_max = 1.01 * max(np.max(base_o + dos_o), np.max(base_u + dos_u))
        plot_DOS(self.ax_occ_dos, self.energy_o, dos_o, base_o,
                 dos_min, dos_max, flip=False, fill=fill, line=line)
        plot_DOS(self.ax_unocc_dos, self.energy_u, dos_u, base_u,
                 dos_min, dos_max, flip=True, fill=fill, line=line)
        if stack:
            self.base_o += dos_o
            self.base_u += dos_u

    def plot_spectrum(self):
        """Not implemented for the base class."""
        raise NotImplementedError()

    def plot_TCM_diagonal(self, energy, **kwargs):
        """Draw the line ``eps_u = eps_o + energy`` on the map.

        The line spans the first to last entries of ``energy_o``; extra
        keyword arguments go to ``Axes.plot``.
        """
        x_o = np.take(self.energy_o, (0, -1))
        self.ax_tcm.plot(x_o, x_o + energy, **kwargs)

    def set_title(self, *args, **kwargs):
        """Set the figure title on the top (occupied-DOS) panel."""
        self.ax_occ_dos.set_title(*args, **kwargs)
class TCMPlotter(TCM):
    """TCM figure whose map and DOS are computed from a decomposition object.

    ``ksd`` must provide ``get_eig_n``, ``get_TCM`` and ``get_weighted_DOS``
    (presumably a Kohn-Sham decomposition object -- confirm at call sites).
    The same broadening ``sigma`` is passed to both the map and the DOS
    calculation.
    """

    def __init__(self, ksd, energy_o, energy_u, sigma,
                 zero_fermilevel=True):
        # ksd supplies the eigenvalues and the Fermi level in one call.
        eigenvalues, fermi = ksd.get_eig_n(zero_fermilevel)
        super().__init__(energy_o, energy_u, fermi)
        self.ksd = ksd
        self.sigma = sigma
        self.eig_n = eigenvalues

    def plot_TCM(self, weight_p, **kwargs):
        """Compute the map for ``weight_p`` and draw it via ``TCM.plot_TCM``."""
        grid_o, grid_u = self.energy_o, self.energy_u
        tcm_ou = self.ksd.get_TCM(weight_p, self.eig_n, grid_o, grid_u,
                                  self.sigma)
        super().plot_TCM(tcm_ou, **kwargs)

    def plot_DOS(self, weight_n=1.0, **kwargs):
        """Compute the weighted DOS and draw it via ``TCM.plot_DOS``."""
        grid_o, grid_u = self.energy_o, self.energy_u
        dos_o, dos_u = self.ksd.get_weighted_DOS(weight_n, self.eig_n,
                                                 grid_o, grid_u, self.sigma)
        super().plot_DOS(dos_o, dos_u, **kwargs)