Coverage for gpaw/new/pw/fulldiag.py: 15%
91 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
3import numpy as np
4from gpaw.core.atom_arrays import AtomArrays
5from gpaw.core.matrix import Matrix, create_distribution
6from gpaw.core.plane_waves import (PWAtomCenteredFunctions,
7 PWArray, PWDesc)
8from gpaw.core.uniform_grid import UGArray
9from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
10from gpaw.typing import Array2D
11from gpaw.new.ibzwfs import IBZWaveFunctions
12from gpaw.new.wave_functions import WaveFunctions
13from gpaw.new.potential import Potential
14from gpaw.new.smearing import OccupationNumberCalculator
def pw_matrix(pw: PWDesc,
              pt_aiG: PWAtomCenteredFunctions,
              dH_aii: AtomArrays,
              dS_aii: list[Array2D],
              vt_R: UGArray,
              dedtaut_R: UGArray | None,
              comm) -> tuple[Matrix, Matrix]:
    """Calculate H and S matrices in plane-wave basis.

    Matrix elements between plane waves::

        O_GG' = ∫ exp(-iG.r) O exp(iG'.r) dr

    of the PAW pseudo Hamiltonian and overlap operators::

        H = T + v~(r) δ(r-r') + Σ_aij p~a_i(r-Ra) ΔHa_ij p~a_j(r'-Ra)
        S =         δ(r-r') + Σ_aij p~a_i(r-Ra) ΔSa_ij p~a_j(r'-Ra)

    If ``dedtaut_R`` is given, an extra (meta-GGA style) contribution
    ``1/2 Σ_v G_v G'_v ∫ exp(-iG'.r) dedtaut(r) exp(iG.r) dr`` is added,
    which is what the arithmetic below works out to.

    Parameters
    ----------
    pw:
        Plane-wave descriptor; must have complex dtype.
    pt_aiG:
        Projector functions in the plane-wave basis.
    dH_aii:
        Per-atom ΔH matrices (one spin channel, as passed by the caller).
    dS_aii:
        Per-atom ΔS matrices, indexed by atom number.
    vt_R:
        Effective pseudo potential on the real-space grid.
    dedtaut_R:
        Optional kinetic-energy-density potential on the real-space grid,
        or None.
    comm:
        Communicator handed to ``create_distribution``; each rank fills
        only its own row range of the matrices.

    Returns
    -------
    (H_GG, S_GG), both complex ``Matrix`` objects of size npw x npw.
    """
    assert pw.dtype == complex
    ng = pw.shape[0]
    layout = create_distribution(ng, ng, comm, -1, 1)
    H_GG = layout.matrix(complex)
    S_GG = layout.matrix(complex)
    row_start, row_stop = layout.my_row_range()

    # Scratch arrays reused for every FFT round trip.
    work_G = pw.empty()
    assert isinstance(work_G, PWArray)  # Fix this!
    work_R = vt_R.desc.new(dtype=complex).zeros()
    assert isinstance(work_R, UGArray)  # Fix this!
    dv = pw.dv

    def multiply_in_real_space(coefficient, G, field_R):
        # Take the single plane wave G (weighted by `coefficient`) to real
        # space, multiply by `field_R` and transform back.  The returned
        # array is a view of the scratch buffer, so consume it right away.
        work_G.data[:] = 0.0
        work_G.data[G] = coefficient
        work_G.ifft(out=work_R)
        work_R.data *= field_R.data
        work_R.fft(out=work_G)
        return work_G.data

    # Local-potential part: one full row per locally owned G.
    for G in range(row_start, row_stop):
        H_GG.data[G - row_start] = dv * multiply_in_real_space(1.0, G, vt_R)

    if dedtaut_R is not None:
        G_Gv = pw.reciprocal_vectors()
        for G in range(row_start, row_stop):
            row = G - row_start
            for v in range(3):
                tmp_G = multiply_in_real_space(1j * G_Gv[G, v], G, dedtaut_R)
                H_GG.data[row] += -0.5j * dv * G_Gv[:, v] * tmp_G

    # Kinetic energy is diagonal in G; pseudo overlap is the identity.
    H_GG.add_to_diagonal(dv * pw.ekin_G[row_start:row_stop])
    S_GG.data[:] = 0.0
    S_GG.add_to_diagonal(dv)

    # Projector (PAW) corrections: build block-diagonal ΔH and ΔS over all
    # projector indices I and add f^* ΔX f^T for the locally owned rows.
    # NOTE(review): the block matrices are real (np.zeros default dtype),
    # so a complex ΔH would lose its imaginary part here.
    pt_aiG._lazy_init()
    assert pt_aiG._lfc is not None
    f_GI = pt_aiG._lfc.expand()
    nproj = f_GI.shape[1]
    dH_II = np.zeros((nproj, nproj))
    dS_II = np.zeros((nproj, nproj))
    offset = 0
    for a, dH_ii in dH_aii.items():
        dS_ii = dS_aii[a]
        block = slice(offset, offset + len(dS_ii))
        dH_II[block, block] = dH_ii
        dS_II[block, block] = dS_ii
        offset = block.stop

    fconj_GI = f_GI[row_start:row_stop].conj()
    H_GG.data += (fconj_GI @ dH_II) @ f_GI.T
    S_GG.data += (fconj_GI @ dS_II) @ f_GI.T

    return H_GG, S_GG
def diagonalize(potential: Potential,
                ibzwfs: IBZWaveFunctions,
                occ_calc: OccupationNumberCalculator,
                nbands: int,
                nelectrons: float) -> IBZWaveFunctions:
    """Diagonalize hamiltonian in plane-wave basis.

    For every wave-function object in ``ibzwfs`` (all k-points and spins),
    the full H and S matrices are built with ``pw_matrix`` and the
    generalized eigenproblem is solved for the lowest ``nbands`` states.
    The eigenvectors become the new plane-wave coefficients.

    Parameters
    ----------
    potential:
        Supplies ``vt_sR``, ``dH_asii`` and (optionally) ``dedtaut_sR``.
    ibzwfs:
        Existing wave functions; used for descriptors, setups, spin
        indices and communicators.
    occ_calc:
        Occupation-number calculator passed to ``calculate_occs``.
    nbands:
        Number of bands to keep.
    nelectrons:
        Number of electrons passed to ``calculate_occs``.

    Returns
    -------
    A new ``IBZWaveFunctions`` with eigenvalues set and occupation numbers
    calculated.

    Raises
    ------
    AssertionError
        If the lowest eigenvalue is not above -1000 (see issue #241).
    """
    vt_sR = potential.vt_sR
    dH_asii = potential.dH_asii
    # Without dedtaut (no meta-GGA) use one None per spin so that indexing
    # by spin works the same way in both cases.
    dedtaut_sR: UGArray | list[None] = [None] * len(vt_sR)
    if potential.dedtaut_sR is not None:
        dedtaut_sR = potential.dedtaut_sR

    band_comm = ibzwfs.band_comm
    # Upper bound on bands per rank (ceil division); the same for all k/spin.
    maxmynbands = (nbands + band_comm.size - 1) // band_comm.size

    wfs_qs: list[list[WaveFunctions]] = []
    for wfs_s in ibzwfs.wfs_qs:
        new_wfs_s: list[WaveFunctions] = []
        for wfs in wfs_s:
            assert isinstance(wfs, PWFDWaveFunctions)
            assert isinstance(wfs.pt_aiX, PWAtomCenteredFunctions)
            dS_aii = [setup.dO_ii for setup in wfs.setups]
            pw = wfs.psit_nX.desc
            H_GG, S_GG = pw_matrix(pw,
                                   wfs.pt_aiX,
                                   dH_asii[:, wfs.spin],
                                   dS_aii,
                                   vt_sR[wfs.spin],
                                   dedtaut_sR[wfs.spin],
                                   band_comm)

            # On return H_GG holds the eigenvectors; they are conjugated to
            # match the row convention used when the matrix was filled.
            eig_n = H_GG.eigh(S_GG, limit=nbands)
            H_GG.complex_conjugate()
            # Explicit check instead of a bare assert so that it is not
            # stripped under ``python -O``.  AssertionError is kept for
            # backward compatibility; ``not >`` also rejects NaN.
            if not eig_n[0] > -1000:
                raise AssertionError('See issue #241')

            psit_nG = pw.empty(nbands, comm=band_comm)
            mynbands = psit_nG.data.shape[0]
            # Redistribute eigenvectors so that each band rank holds its
            # own bands, then copy them into the new coefficient array.
            C_nG = H_GG.new(
                dist=(band_comm, band_comm.size, 1, maxmynbands, 1))
            H_GG.redist(C_nG)
            psit_nG.data[:] = C_nG.data[:mynbands]

            new_wfs = PWFDWaveFunctions.from_wfs(wfs, psit_nX=psit_nG)
            new_wfs._eig_n = eig_n
            new_wfs_s.append(new_wfs)
        wfs_qs.append(new_wfs_s)

    new_ibzwfs = IBZWaveFunctions(
        ibzwfs.ibz,
        ncomponents=ibzwfs.ncomponents,
        wfs_qs=wfs_qs,
        kpt_comm=ibzwfs.kpt_comm,
        kpt_band_comm=ibzwfs.kpt_band_comm,
        comm=ibzwfs.comm)

    new_ibzwfs.calculate_occs(occ_calc, nelectrons)

    return new_ibzwfs