Coverage for gpaw/new/lcao/eigensolver.py: 92%

37 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1import numpy as np 

2 

3from gpaw.new.eigensolver import Eigensolver, calculate_weights 

4from gpaw.new.lcao.hamiltonian import HamiltonianMatrixCalculator 

5from gpaw.new.lcao.wave_functions import LCAOWaveFunctions 

6from gpaw.new.energies import DFTEnergies 

7from gpaw.core.matrix import MatrixWithNoData 

8 

9 

class LCAOEigensolver(Eigensolver):
    """Eigensolver for LCAO wave functions.

    Every call to :meth:`iterate` builds the Hamiltonian matrix for each
    wave-function object in ``ibzwfs``, diagonalizes it directly with
    ``H_MM.eighg`` and stores the resulting eigenvalues and coefficient
    matrix on the wave functions.
    """

    def __init__(self,
                 basis,
                 converge_bands='occupied'):
        """Create the eigensolver.

        Parameters
        ----------
        basis:
            Stored as ``self.basis``; not used by the methods of this
            class.
        converge_bands:
            Setting forwarded to ``calculate_weights`` in :meth:`iterate`
            to obtain the per-band weights (default ``'occupied'``).
        """
        self.basis = basis
        self.converge_bands = converge_bands

    def iterate(self,
                ibzwfs,
                density,
                potential,
                hamiltonian,
                pot_calc=None,
                energies=None) -> tuple[float, float, DFTEnergies]:
        """Diagonalize the LCAO Hamiltonian for all wave functions.

        ``density`` and ``pot_calc`` are not used by this method.

        Returns
        -------
        tuple
            ``(eig_error, 0.0, energies)``:

            * ``eig_error`` is the largest eigenvalue error returned by
              ``iterate_kpt`` over the wave functions in ``ibzwfs``.
              It is combined across ``ibzwfs.kpt_band_comm`` with
              ``max_scalar``.
            * The second element is always ``0.0`` for this solver.
            * ``energies`` is passed through unchanged and may be
              ``None``.
        """
        matrix_calculator = hamiltonian.create_hamiltonian_matrix_calculator(
            potential)

        weight_un = calculate_weights(self.converge_bands, ibzwfs)
        eig_error = 0.0
        for wfs, weight_n in zip(ibzwfs, weight_un):
            _, kpt_eig_error = self.iterate_kpt(
                wfs, weight_n, self.iterate1,
                matrix_calculator=matrix_calculator)
            eig_error = max(eig_error, kpt_eig_error)

        # Combine the per-rank maxima across the k-point/band communicator.
        eig_error = ibzwfs.kpt_band_comm.max_scalar(eig_error)
        return eig_error, 0.0, energies

    def iterate1(self,
                 wfs: LCAOWaveFunctions,
                 weight_n: np.ndarray,  # XXX: Unused
                 matrix_calculator: HamiltonianMatrixCalculator) -> None:
        """Diagonalize the Hamiltonian for one wave-function object.

        The method does four things:

        * Build ``H_MM`` with ``matrix_calculator``.
        * Call ``H_MM.eighg(wfs.L_MM, wfs.domain_comm)``, which
          overwrites ``H_MM`` with the eigenvectors.
        * Store the first ``wfs.nbands`` eigenvalues in ``wfs._eig_n``
          and the matching eigenvectors, transposed to rows, in
          ``wfs.C_nM``.
        * Reset ``wfs._P_ani`` so the lazily computed projections are
          rebuilt from the new coefficients.
        """
        H_MM = matrix_calculator.calculate_matrix(wfs)
        eig_M = H_MM.eighg(wfs.L_MM, wfs.domain_comm)
        C_Mn = H_MM  # rename (H_MM now contains the eigenvectors)

        N = wfs.nbands
        assert len(eig_M) >= N
        # Single copy into a new float64 array (same dtype as the former
        # np.empty buffer), independent of eig_M.
        wfs._eig_n = np.array(eig_M[:N], dtype=float)

        if isinstance(wfs.C_nM, MatrixWithNoData):
            # C_nM has no data buffer yet; replace it with a real matrix
            # before writing into .data below.
            wfs.C_nM = wfs.C_nM.create()

        if C_Mn.dist.comm.size == 1:
            # Eigenvectors are all local: copy the first N columns as rows.
            wfs.C_nM.data[:] = C_Mn.data.T[:N]
        else:
            # Eigenvectors are distributed: gather them (broadcast=True,
            # so every rank receives the result), then each rank copies
            # only the rows in its own row range of C_nM.
            C_Mn = C_Mn.gather(broadcast=True)
            n1, n2 = wfs.C_nM.dist.my_row_range()
            wfs.C_nM.data[:] = C_Mn.data.T[n1:n2]

        # Make sure wfs.C_nM and (lazy) wfs.P_ani are in sync:
        wfs._P_ani = None