Coverage for gpaw/response/pair_transitions.py: 89%
103 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from __future__ import annotations
3import numpy as np
class PairTransitions:
    """Bookkeeping object for transitions in band and spin indices.

    All transitions between different band and spin indices (for a given pair
    of k-points k and k + q) are accounted for via a single transition index
    t,

    t (composite transition index): (n, s) -> (n', s')
    """

    def __init__(self, n1_t, n2_t, s1_t, s2_t):
        """Construct the PairTransitions object.

        Parameters
        ----------
        n1_t : array_like
            Band index of k-point k for each transition t.
        n2_t : array_like
            Band index of k-point k + q for each transition t.
        s1_t : array_like
            Spin index of k-point k for each transition t.
        s2_t : array_like
            Spin index of k-point k + q for each transition t.

        Raises
        ------
        ValueError
            If the four index arrays do not all have the same length.
        """
        # np.asarray returns ndarray inputs unchanged, but also converts
        # array-likes (e.g. lists), so that the elementwise comparisons in
        # get_intraband_mask() produce a mask rather than a single bool.
        self.n1_t = np.asarray(n1_t)
        self.n2_t = np.asarray(n2_t)
        self.s1_t = np.asarray(s1_t)
        self.s2_t = np.asarray(s2_t)

        # Explicit check rather than assert, so that it survives python -O.
        nt = len(self)
        for name, index_t in (('n2_t', self.n2_t), ('s1_t', self.s1_t),
                              ('s2_t', self.s2_t)):
            if len(index_t) != nt:
                raise ValueError(
                    f'{name} has length {len(index_t)}, but n1_t has '
                    f'length {nt}')

    def __len__(self):
        """Return the number of transitions t."""
        return len(self.n1_t)

    def get_band_indices(self):
        """Return the band indices (n1_t, n2_t) of k and k + q."""
        return self.n1_t, self.n2_t

    def get_spin_indices(self):
        """Return the spin indices (s1_t, s2_t) of k and k + q."""
        return self.s1_t, self.s2_t

    def get_intraband_mask(self):
        """Get mask for selecting intraband transitions.

        A transition counts as intraband when both the band index and the
        spin index are unchanged, i.e. n1 == n2 and s1 == s2.
        """
        intraband_t = (self.n1_t == self.n2_t) & (self.s1_t == self.s2_t)
        return intraband_t

    @classmethod
    def from_transitions_domain_arguments(cls, spincomponent,
                                          nbands, nocc1, nocc2, nspins,
                                          bandsummation) -> PairTransitions:
        """Generate the band and spin transitions integration domain.

        The integration domain is determined by the spin rotation (from spin
        index s to spin index s'), the number of bands and spins in the
        underlying ground state calculation as well as the band summation
        scheme.

        The integration domain automatically excludes transitions between two
        occupied bands and two unoccupied bands respectively.

        Parameters
        ----------
        spincomponent : str
            Spin component (μν) of the pair function.
            Currently, '00', 'uu', 'dd', '+-' and '-+' are implemented.
        nbands : int
            Maximum band index to include.
        nocc1 : int
            Number of completely filled bands in the ground state calculation
        nocc2 : int
            Number of non-empty bands in the ground state calculation
        nspins : int
            Number of spin channels in the ground state calculation (1 or 2)
        bandsummation : str
            Band (and spin) summation scheme for pairs of Kohn-Sham orbitals
            'pairwise': sum over pairs of bands (and spins)
            'double': double sum over band (and spin) indices.
        """
        # Band pairs M = (n1, n2), with fully occupied/empty pairs removed
        n1_M, n2_M = get_band_transitions_domain(bandsummation, nbands,
                                                 nocc1=nocc1,
                                                 nocc2=nocc2)
        # Spin pairs S = (s1, s2)
        s1_S, s2_S = get_spin_transitions_domain(bandsummation,
                                                 spincomponent, nspins)

        # Combine band and spin pairs into the composite index t
        n1_t, n2_t, s1_t, s2_t = transitions_in_composite_index(n1_M, n2_M,
                                                                s1_S, s2_S)

        return cls(n1_t, n2_t, s1_t, s2_t)
def get_band_transitions_domain(bandsummation, nbands, nocc1=None, nocc2=None):
    """Collect the pairs of bands (n1, n2) entering the band summation.

    Parameters
    ----------
    bandsummation : str
        Band summation method ('pairwise' or 'double').
    nbands : int
        Number of bands.
    nocc1 : int, optional
        Number of completely filled bands.
    nocc2 : int, optional
        Number of non-empty bands.

    Returns
    -------
    n1_M : ndarray
        First band index of each pair, M = (n1, n2) composite index.
    n2_M : ndarray
        Second band index of each pair, M = (n1, n2) composite index.
    """
    band_domain_builder = create_get_band_transitions_domain(bandsummation)
    all_n1_M, all_n2_M = band_domain_builder(nbands)
    # Discard pairs between which no transition can take place
    return remove_null_transitions(all_n1_M, all_n2_M,
                                   nocc1=nocc1, nocc2=nocc2)
def create_get_band_transitions_domain(bandsummation):
    """Select the function that builds the band domain for a summation scheme.

    Raises
    ------
    ValueError
        If ``bandsummation`` is neither 'pairwise' nor 'double'.
    """
    if bandsummation != 'pairwise' and bandsummation != 'double':
        raise ValueError(bandsummation)
    if bandsummation == 'double':
        return get_double_band_transitions_domain
    return get_pairwise_band_transitions_domain
def get_double_band_transitions_domain(nbands):
    """Enumerate every (n, m) pair with 0 <= n, m < nbands (plain double sum).

    The first index runs fastest: pairs are ordered (0, 0), (1, 0), ...,
    (nbands - 1, 0), (0, 1), ...
    """
    bands = np.arange(0, nbands)
    count = bands.size
    first_M = np.tile(bands, count)     # 0, 1, ..., 0, 1, ...
    second_M = np.repeat(bands, count)  # 0, 0, ..., 1, 1, ...
    return first_M, second_M
def get_pairwise_band_transitions_domain(nbands):
    """Make a sum over all pairs (n, m) with 0 <= n <= m < nbands.

    Pairs are returned in row-major upper-triangle order:
    (0, 0), (0, 1), ..., (0, nbands - 1), (1, 1), ...

    The index arrays always have an integer dtype, also when nbands == 0.
    This lets them be used directly for indexing.
    """
    n_M, m_M = np.triu_indices(nbands)
    return n_M, m_M
def remove_null_transitions(n1_M, n2_M, nocc1=None, nocc2=None):
    """Remove pairs of bands, between which transitions are impossible.

    A pair (n1, n2) is dropped when both bands are completely filled
    (n1 < nocc1 and n2 < nocc1), or when both bands are completely empty
    (n1 >= nocc2 and n2 >= nocc2). Passing None for a threshold disables
    that check.

    Parameters
    ----------
    n1_M, n2_M : array_like
        Band indices of the two bands in each pair M (equal lengths).
    nocc1 : int or None
        Number of completely filled bands.
    nocc2 : int or None
        Number of non-empty bands.

    Returns
    -------
    n1_newM, n2_newM : ndarray
        The remaining pairs in their original order. The input dtype is
        preserved, also when every pair is removed.
    """
    n1_M = np.asarray(n1_M)
    n2_M = np.asarray(n2_M)

    keep_M = np.ones(n1_M.shape, dtype=bool)
    if nocc1 is not None:
        keep_M &= ~((n1_M < nocc1) & (n2_M < nocc1))  # both fully occupied
    if nocc2 is not None:
        keep_M &= ~((n1_M >= nocc2) & (n2_M >= nocc2))  # both unoccupied

    return n1_M[keep_M], n2_M[keep_M]
def get_spin_transitions_domain(bandsummation, spincomponent, nspins):
    """Collect the spin pairs (s1, s2) entering the spin summation.

    Parameters
    ----------
    bandsummation : str
        Band summation method ('pairwise' or 'double').
    spincomponent : str
        Spin component (μν) of the pair function.
        Currently, '00', 'uu', 'dd', '+-' and '-+' are implemented.
    nspins : int
        Number of spin channels in the ground state calculation.

    Returns
    -------
    s1_S : ndarray
        Spin index 1, S = (s1, s2) composite index.
    s2_S : ndarray
        Spin index 2, S = (s1, s2) composite index.
    """
    spin_domain_builder = create_get_spin_transitions_domain(bandsummation)
    s1_S, s2_S = spin_domain_builder(spincomponent, nspins)
    return s1_S, s2_S
def create_get_spin_transitions_domain(bandsummation):
    """Select the function that builds the spin domain for a summation scheme.

    Raises
    ------
    ValueError
        If ``bandsummation`` is neither 'pairwise' nor 'double'.
    """
    if bandsummation != 'pairwise' and bandsummation != 'double':
        raise ValueError(bandsummation)
    if bandsummation == 'double':
        return get_double_spin_transitions_domain
    return get_pairwise_spin_transitions_domain
def get_double_spin_transitions_domain(spincomponent, nspins):
    """Usual spin rotations forward in time.

    Parameters
    ----------
    spincomponent : str
        Spin component (μν) of the pair function: '00', 'uu', 'dd', '+-' or
        '-+'. Only '00' is available for spin-paired calculations.
    nspins : int
        Number of spin channels in the ground state calculation (1 or 2).

    Returns
    -------
    s1_S, s2_S : ndarray
        Spin indices (0 = up, 1 = down) before and after the transition.

    Raises
    ------
    ValueError
        If nspins is not 1 or 2, or if the spin component is not available
        for the given number of spins.
    """
    # Previously any nspins other than 1 was silently treated as 2
    if nspins not in (1, 2):
        raise ValueError(f'nspins must be 1 or 2, got {nspins}')

    if nspins == 1:
        if spincomponent != '00':
            raise ValueError(spincomponent, nspins)
        return np.array([0]), np.array([0])

    spin_pairs = {
        '00': ([0, 1], [0, 1]),
        'uu': ([0], [0]),
        'dd': ([1], [1]),
        '+-': ([0], [1]),  # spin up -> spin down
        '-+': ([1], [0]),  # spin down -> spin up
    }
    if spincomponent not in spin_pairs:
        raise ValueError(spincomponent)
    s1_S, s2_S = spin_pairs[spincomponent]
    return np.array(s1_S), np.array(s2_S)
def get_pairwise_spin_transitions_domain(spincomponent, nspins):
    """In a sum over pairs, transitions including a spin rotation may have to
    include terms, propagating backwards in time.

    For the spin-flip components '+-' and '-+', both the up -> down and the
    down -> up transitions are returned. All other components fall back to
    get_double_spin_transitions_domain().

    Raises
    ------
    ValueError
        If a spin-flip component is requested for a calculation that does not
        have two spin channels.
    """
    if spincomponent in ['+-', '-+']:
        # Explicit check instead of assert, so it survives python -O
        if nspins != 2:
            raise ValueError(
                f'Spin component {spincomponent} requires nspins == 2, '
                f'got {nspins}')
        return np.array([0, 1]), np.array([1, 0])
    return get_double_spin_transitions_domain(spincomponent, nspins)
def transitions_in_composite_index(n1_M, n2_M, s1_S, s2_S):
    """Combine band pairs M and spin pairs S into one transition index t.

    Every band pair is combined with every spin pair. The band index runs
    fastest, so t = S * len(M) + M.
    """
    flat_pairs = []
    for band_M, spin_S in ((n1_M, s1_S), (n2_M, s2_S)):
        band_SM, spin_SM = np.meshgrid(band_M, spin_S)
        flat_pairs.append((band_SM.flatten(), spin_SM.flatten()))
    (n1_t, s1_t), (n2_t, s2_t) = flat_pairs
    return n1_t, n2_t, s1_t, s2_t