Coverage for gpaw/lcao/tci.py: 99%
217 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1import numpy as np
2import scipy.sparse as sparse
3from ase.neighborlist import PrimitiveNeighborList
4# from ase.utils.timing import timer
5from gpaw.utilities.tools import tri2full
7# from gpaw import debug
8from gpaw.lcao.overlap import (FourierTransformer, TwoSiteOverlapCalculator,
9 ManySiteOverlapCalculator,
10 AtomicDisplacement, NullPhases, BlochPhases,
11 DerivativeAtomicDisplacement)
def get_cutoffs(f_Ij):
    """Return the largest cutoff found in each group of functions.

    ``f_Ij`` is a sequence of groups; every member of a group must provide
    a ``get_cutoff()`` method.  One number is returned per group.  The
    result never drops below 0.001, which is also what an empty group
    yields.
    """
    floor = 0.001  # 'paranoid zero'
    return [max([floor] + [f.get_cutoff() for f in f_j]) for f_j in f_Ij]
def get_lvalues(f_Ij):
    """Collect the angular momentum number of every function, per group.

    The nesting of ``f_Ij`` is preserved: the result holds one list per
    group, containing ``f.get_angular_momentum_number()`` for each member
    ``f`` in order.
    """
    l_Ij = []
    for f_j in f_Ij:
        l_Ij.append([f.get_angular_momentum_number() for f in f_j])
    return l_Ij
class AtomPairRegistry:
    """Table of displacement vectors between neighbouring atoms.

    For every ordered pair (a1, a2) reported by the neighbor list, the
    registry stores a list of ``(R_c, offset)`` tuples, where ``R_c`` is
    the displacement from a1 to a2 (scaled position difference multiplied
    by ``cell_cv``) and ``offset`` is the cell offset of that image.
    """

    def __init__(self, cutoff_a, pbc_c, cell_cv, spos_ac):
        neighborlist = PrimitiveNeighborList(cutoff_a, skin=0, sorted=True,
                                             self_interaction=True,
                                             use_scaled_positions=True)
        neighborlist.update(pbc=pbc_c, cell=cell_cv, coordinates=spos_ac)

        registry = {}
        for a1, spos1_c in enumerate(spos_ac):
            neighbors_a, offsets = neighborlist.get_neighbors(a1)
            for a2, offset in zip(neighbors_a, offsets):
                R_c = np.dot(spos_ac[a2] + offset - spos1_c, cell_cv)
                registry.setdefault((a1, a2), []).append((R_c, offset))

                # Also record the reversed pair, except for an atom paired
                # with itself in the same cell (the mirror would be a
                # duplicate of the entry just added).
                same_atom_same_cell = a1 == a2 and not offset.any()
                if not same_atom_same_cell:
                    registry.setdefault((a2, a1), []).append((-R_c, -offset))

        self.r_and_offset_aao = registry

    def get(self, a1, a2):
        """Return the list of (R_c, offset) for the pair, or None."""
        return self.r_and_offset_aao.get((a1, a2))

    def get_atompairs(self):
        """Return all registered (a1, a2) keys as a sorted list."""
        return sorted(self.r_and_offset_aao)
class TCIExpansions:
    """Precomputed expansions for two-center integrals.

    Parameters
    ----------
    phit_Ij:
        For each species I, the list of basis functions.
    pt_Ij:
        For each species I, the list of projector functions.  Must have
        the same length as ``phit_Ij``.
    I_a:
        For each atom a, the species index I into the lists above.

    Attributes
    ----------
    O_expansions, T_expansions, P_expansions:
        Expansion objects for basis/basis overlap, basis/basis kinetic
        and projector/basis overlap, as produced by
        ``ManySiteOverlapCalculator``.
    rcmax_I, phit_rcmax_I, pt_rcmax_I:
        Per-species cutoffs: overall maximum, basis functions only and
        projectors only.
    """

    def __init__(self, phit_Ij, pt_Ij, I_a):
        assert len(pt_Ij) == len(phit_Ij)

        # Cutoffs by species:
        pt_rcmax_I = get_cutoffs(pt_Ij)
        phit_rcmax_I = get_cutoffs(phit_Ij)
        rcmax_I = [max(rc1, rc2) for rc1, rc2
                   in zip(pt_rcmax_I, phit_rcmax_I)]

        # The extra 1e-3 keeps max() well defined when there are no species.
        transformer = FourierTransformer(rcut=max(rcmax_I + [1e-3]), N=2**10)
        tsoc = TwoSiteOverlapCalculator(transformer)
        msoc = ManySiteOverlapCalculator(tsoc, I_a, I_a)
        phit_Ijq = msoc.transform(phit_Ij)
        pt_Ijq = msoc.transform(pt_Ij)
        pt_l_Ij = get_lvalues(pt_Ij)
        phit_l_Ij = get_lvalues(phit_Ij)
        self.O_expansions = msoc.calculate_expansions(phit_l_Ij, phit_Ijq,
                                                      phit_l_Ij, phit_Ijq)
        self.T_expansions = msoc.calculate_kinetic_expansions(phit_l_Ij,
                                                              phit_Ijq)
        self.P_expansions = msoc.calculate_expansions(pt_l_Ij, pt_Ijq,
                                                      phit_l_Ij, phit_Ijq)
        self.I_a = I_a  # Actually I_a belongs outside, like spos_ac.
        self.rcmax_I = rcmax_I
        self.phit_rcmax_I = phit_rcmax_I
        self.pt_rcmax_I = pt_rcmax_I

    @classmethod
    def new_from_setups(cls, setups):
        """Build the expansions from a setups container.

        ``setups.setups`` supplies the distinct setup objects (one per
        species); iterating over ``setups`` itself yields the setup of
        each atom.  Basis functions are read from ``basis_functions_J``
        and projectors from ``pt_j`` on every distinct setup.

        The instance is created through ``cls`` so that subclasses get an
        instance of their own type.
        """
        setups_I = list(setups.setups.values())
        I_setup = {setup: I for I, setup in enumerate(setups_I)}
        I_a = [I_setup[setup] for setup in setups]

        return cls([s.basis_functions_J for s in setups_I],
                   [s.pt_j for s in setups_I],
                   I_a)

    def get_tci_calculator(self, cell_cv, spos_ac, pbc_c, ibzk_qc, dtype):
        """Return a TCICalculator using these expansions."""
        return TCICalculator(self, cell_cv, spos_ac, pbc_c, ibzk_qc, dtype)

    def get_manytci_calculator(self, setups, gd, spos_ac, ibzk_qc, dtype,
                               timer):
        """Return a ManyTCICalculator using these expansions."""
        return ManyTCICalculator(self, setups, gd, spos_ac, ibzk_qc, dtype,
                                 timer)
class TCICalculator:
    """Two-center integral calculator for individual atom pairs.

    The calculator knows nothing about parallelization; any pair of
    atoms a1, a2 may be requested.

    Once constructed, four callables taking ``(a1, a2)`` are available:

      P_qim = tci.P(a1, a2)
          Projector/basis overlap <pt_i^a1|phi_mu> between atoms a1, a2.

      dPdR_qvim = tci.dPdR(a1, a2)
          Derivatives of the above with respect to movement of a2.

      O_qmm, T_qmm = tci.O_T(a1, a2)
          Basis/basis overlap and kinetic matrix elements between a1, a2.

      dOdR_qvmm, dTdR_qvmm = tci.dOdR_dTdR(a1, a2)
          Derivative of the above wrt. position of a2.

    When the two atoms do not overlap, ``P``/``dPdR`` return None and
    ``O_T``/``dOdR_dTdR`` return ``(None, None)``.
    """

    def __init__(self, tciexpansions, cell_cv, spos_ac, pbc_c, ibzk_qc,
                 dtype):
        self.tciexpansions = tciexpansions
        self.dtype = dtype

        # XXX It is somewhat nasty that rcmax depends on how long our
        # longest orbital happens to be
        # Translate the per-species cutoffs into per-atom ones:
        species_a = tciexpansions.I_a
        pair_cutoff_a = [tciexpansions.rcmax_I[I] for I in species_a]
        self.pt_rcmax_a = np.array([tciexpansions.pt_rcmax_I[I]
                                    for I in species_a])
        self.phit_rcmax_a = np.array([tciexpansions.phit_rcmax_I[I]
                                      for I in species_a])

        self.a1a2 = AtomPairRegistry(pair_cutoff_a, pbc_c, cell_cv, spos_ac)

        self.ibzk_qc = ibzk_qc
        # Phase factors are only needed if some k-point is nonzero.
        self.get_phases = BlochPhases if ibzk_qc.any() else NullPhases

        self.O_T = self._tci_shortcut(False, False)
        self.P = self._tci_shortcut(True, False)
        self.dOdR_dTdR = self._tci_shortcut(False, True)
        self.dPdR = self._tci_shortcut(True, True)

    def _tci_shortcut(self, P, derivative):
        """Return a function of (a1, a2) with P and derivative fixed."""
        def calculate(a1, a2):
            return self._calculate(a1, a2, P, derivative)
        return calculate

    def _calculate(self, a1, a2, P=False, derivative=False):
        """Calculate overlap of functions between atoms a1 and a2.

        With ``P`` true, the projector/basis array is returned; otherwise
        the tuple (overlap, kinetic) of basis/basis arrays.  With
        ``derivative`` true, the arrays get an extra axis of length 3
        after the k-point axis.
        """
        no_overlap = None if P else (None, None)

        # Quick exit if the pair is not registered at all, i.e. the
        # distance is outside the bounding spheres.
        candidates = self.a1a2.get(a1, a2)
        if candidates is None:
            return no_overlap

        phit_rcmax_a = self.phit_rcmax_a
        pt_rcmax_a = self.pt_rcmax_a
        rcut1 = pt_rcmax_a[a1] if P else phit_rcmax_a[a1]
        rcut2 = phit_rcmax_a[a2]
        maxdist = rcut1 + rcut2

        # Keep only the displacements shorter than maxdist:
        in_range = [pair for pair in candidates
                    if np.linalg.norm(pair[0]) < maxdist]
        if not in_range:  # There was no overlap after all
            return no_overlap

        dtype = self.dtype
        get_phases = self.get_phases
        if derivative:
            displacement = DerivativeAtomicDisplacement
        else:
            displacement = AtomicDisplacement
        ibzk_qc = self.ibzk_qc
        nq = len(ibzk_qc)
        shape = (nq, 3) if derivative else (nq,)

        # Pair every expansion with the array it accumulates into.
        expansions = self.tciexpansions
        if P:
            P_expansion = expansions.P_expansions.get(a1, a2)
            P_qim = P_expansion.zeros(shape, dtype=dtype)
            targets = [(P_expansion, P_qim)]
            result = P_qim
        else:
            O_expansion = expansions.O_expansions.get(a1, a2)
            T_expansion = expansions.T_expansions.get(a1, a2)
            O_qmm = O_expansion.zeros(shape, dtype=dtype)
            T_qmm = T_expansion.zeros(shape, dtype=dtype)
            targets = [(O_expansion, O_qmm), (T_expansion, T_qmm)]
            result = (O_qmm, T_qmm)

        for R_c, offset in in_range:
            norm = np.linalg.norm(R_c)
            phases = get_phases(ibzk_qc, offset)
            disp = displacement(None, a1, a2, R_c, offset, phases)

            # maxdist is pt_rcmax_a[a1] + phit_rcmax_a[a2] for P and
            # phit_rcmax_a[a1] + phit_rcmax_a[a2] otherwise.
            assert norm < maxdist
            for expansion, x_qxx in targets:
                disp.evaluate_overlap(expansion, x_qxx)

        return result
class ManyTCICalculator:
    """Assemble per-pair two-center integrals into arrays over all atoms.

    Wraps a TCICalculator (created from ``tciexpansions`` with the cell
    and boundary conditions taken from ``gd``) and uses the index ranges
    from ``setups.projector_indices()`` and ``setups.basis_indices()`` to
    place each atom-pair block in the combined arrays.
    """

    def __init__(self, tciexpansions, setups, gd, spos_ac, ibzk_qc, dtype,
                 timer):
        self.tci = tciexpansions.get_tci_calculator(gd.cell_cv, spos_ac,
                                                    gd.pbc_c,
                                                    ibzk_qc, dtype)

        self.setups = setups
        self.dtype = dtype
        self.Pindices = setups.projector_indices()
        self.Mindices = setups.basis_indices()
        self.natoms = len(setups)
        self.nq = len(ibzk_qc)
        self.nao = self.Mindices.max
        self.timer = timer

    # @timer('tci-projectors')
    def P_aqMi(self, my_atom_indices, derivative=False):
        """Return dense projector/basis arrays for the given atoms.

        The result maps each a1 in ``my_atom_indices`` to an array of
        shape (nq, nao, ni), or (nq, 3, nao, ni) if ``derivative``, where
        ni is ``setups[a1].ni``.  Each block is the conjugate of the pair
        result with its last two axes swapped; pairs without overlap give
        zeros.  Derivative arrays are multiplied by -1.
        """
        if derivative:
            pair_P = self.tci.dPdR
            leading = (self.nq, 3)
        else:
            pair_P = self.tci.P
            leading = (self.nq,)

        Mindices = self.Mindices
        P_axMi = {}

        for a1 in my_atom_indices:
            ni = self.setups[a1].ni
            P_xMi = np.empty(leading + (self.nao, ni), self.dtype)

            for a2 in range(self.natoms):
                N1, N2 = Mindices[a2]
                P_xim = pair_P(a1, a2)
                if P_xim is None:
                    P_xMi[..., N1:N2, :] = 0.0
                else:
                    P_xMi[..., N1:N2, :] = P_xim.swapaxes(-2, -1).conj()
            P_axMi[a1] = P_xMi

        if derivative:
            for P_xMi in P_axMi.values():
                P_xMi *= -1.0
        return P_axMi

    # @timer('tci-sparseprojectors')
    def P_qIM(self, my_atom_indices):
        """Return one sparse CSR projector/basis matrix per k-point.

        Each matrix has shape (Pindices.max, Mindices.max).  Only the
        rows of the atoms in ``my_atom_indices`` are filled, and only for
        pairs that overlap.
        """
        nq = self.nq
        pair_P = self.tci.P
        full_shape = (self.Pindices.max, self.Mindices.max)
        # Build as DOK matrices, which support block assignment, and
        # convert to CSR at the end.
        dok_q = [sparse.dok_matrix(full_shape, dtype=self.dtype)
                 for _ in range(nq)]

        for a1 in my_atom_indices:
            I1, I2 = self.Pindices[a1]

            # We can stride a2 over e.g. bd.comm and then do bd.comm.sum().
            # How should we do comm.sum() on a sparse matrix though?
            for a2 in range(self.natoms):
                M1, M2 = self.Mindices[a2]
                P_qim = pair_P(a1, a2)
                if P_qim is None:
                    continue
                for q, P_IM in enumerate(dok_q):
                    P_IM[I1:I2, M1:M2] = P_qim[q]

        return [P_IM.tocsr() for P_IM in dok_q]

    # @timer('tci-bfs')
    def O_qMM_T_qMM(self, gdcomm, Mstart, Mstop, ignore_upper=False,
                    derivative=False):
        """Return overlap and kinetic matrices for rows Mstart:Mstop.

        Both arrays have shape (nq, Mstop - Mstart, nao), with an extra
        axis of length 3 after the first one if ``derivative``.  Only
        pairs with a2 <= a1 are calculated, and a2 is strided over
        ``gdcomm.rank``/``gdcomm.size``, so each rank fills only part of
        the blocks.

        Unless ``ignore_upper`` is true (or the arrays are empty), the
        row range must cover all of nao and ``tri2full`` is applied to
        every square matrix with ``UL='L'``, using complex conjugation
        as the map (negated conjugation for derivatives).
        """
        mynao = Mstop - Mstart
        Mindices = self.Mindices

        if derivative:
            pair_O_T = self.tci.dOdR_dTdR
            leading = (self.nq, 3)
        else:
            pair_O_T = self.tci.O_T
            leading = (self.nq,)
        shape = leading + (mynao, self.nao)

        O_xMM = np.zeros(shape, self.dtype)
        T_xMM = np.zeros(shape, self.dtype)

        # XXX the a1/a2 loops are not yet well load balanced.
        for a1 in range(self.natoms):
            M1, M2 = Mindices[a1]
            touches_my_rows = M2 > Mstart and M1 < Mstop
            if not touches_my_rows:
                continue

            # Rows of this atom, clipped to the local row range.
            myM1 = max(M1 - Mstart, 0)
            myM2 = min(M2 - Mstart, mynao)
            nM = myM2 - myM1

            assert nM > 0, nM

            # Matching rows inside the pair block.
            # (Slice may go beyond end of matrix but OK)
            m1 = max(Mstart - M1, 0)
            block_rows = slice(m1, m1 + nM)

            a2max = a1 + 1  # if not derivative else self.natoms

            for a2 in range(gdcomm.rank, a2max, gdcomm.size):
                O_xmm, T_xmm = pair_O_T(a1, a2)
                if O_xmm is None:
                    continue

                N1, N2 = Mindices[a2]
                O_xMM[..., myM1:myM2, N1:N2] = O_xmm[..., block_rows, :]
                T_xMM[..., myM1:myM2, N1:N2] = T_xmm[..., block_rows, :]

        # reshape() fails on size-0 arrays
        if ignore_upper or not O_xMM.size:
            return O_xMM, T_xMM

        assert mynao == self.nao
        assert O_xMM.shape[-2:] == (self.nao, self.nao)

        if derivative:
            def lumap(arr, out):
                np.conj(arr, out)
                out *= -1.0
        else:
            lumap = np.conj

        for arr_xMM in (O_xMM, T_xMM):
            for tmp_MM in arr_xMM.reshape(-1, self.nao, self.nao):
                tri2full(tmp_MM, UL='L', map=lumap)

        return O_xMM, T_xMM