Coverage for gpaw/new/smearing.py: 82%

34 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from __future__ import annotations 

2from gpaw.occupations import create_occ_calc, ParallelLayout 

3from gpaw.band_descriptor import BandDescriptor 

4from gpaw.typing import ArrayLike2D, Array2D 

5 

6 

class OccupationNumberCalculator:
    """Wrapper around an occupation-number calculator from gpaw.occupations.

    The wrapped object, ``self.occ``, is created by ``create_occ_calc``.
    This class picks default parameters, validates a few options and
    forwards the remaining calls to ``self.occ``.
    """

    def __init__(self,
                 dct,
                 pbc,
                 ibz,
                 nbands,
                 comms,
                 magmom_v,
                 ncomponents,
                 nelectrons,
                 rcell):
        """Create the underlying occupation-number calculator.

        Parameters
        ----------
        dct:
            Occupations parameters, e.g. ``{'name': 'fermi-dirac',
            'width': 0.1}``.  When empty or None a default is used:
            Fermi-Dirac with a 0.1 eV width if ``pbc.any()`` is true,
            otherwise ``{'width': 0.0}``.  The caller's dict is never
            modified (it is copied before ``'fixmagmom'`` is removed).
        pbc:
            Periodic boundary conditions; only ``pbc.any()`` is used.
        ibz:
            Irreducible Brillouin zone; ``len(ibz)``, ``ibz.bz`` (its
            optional ``size_c``) and ``ibz.bz2ibz_K`` are used.
        nbands:
            Number of bands.
        comms:
            Mapping of communicators; the ``'k'`` and ``'K'`` entries
            are passed on to ``ParallelLayout``.
        magmom_v:
            Magnetic moment vector; element ``[2]`` is passed on as the
            fixed magnetic moment value.
        ncomponents:
            Number of spin components.
        nelectrons:
            Number of electrons.
        rcell:
            Passed on unchanged to ``create_occ_calc``.

        Raises
        ------
        ValueError
            If ``'fixmagmom'`` is requested and ``ncomponents`` is
            neither 1 nor 2.
        NotImplementedError
            If ``'mom'`` occupations are requested.
        """
        if not dct:
            if pbc.any():
                dct = {'name': 'fermi-dirac',
                       'width': 0.1}  # eV
            else:
                dct = {'width': 0.0}

        if dct.get('fixmagmom'):
            if ncomponents == 1:
                # The option is silently dropped for ncomponents == 1
                # (presumably a spin-paired calculation has no moment to
                # fix).  Copy first so the caller's dict is untouched.
                dct = dct.copy()
                del dct['fixmagmom']
            elif ncomponents != 2:
                raise ValueError(
                    'fixmagmom requires ncomponents == 2, '
                    f'got ncomponents={ncomponents}')

        if dct.get('name', '') == 'mom':
            raise NotImplementedError(
                "'mom' occupations are not supported here")

        bd = BandDescriptor(nbands)  # dummy
        # Note that eigenvalues are not distributed over
        # the band communicator.
        self.occ = create_occ_calc(
            dct,
            parallel_layout=ParallelLayout(bd,
                                           comms['k'],
                                           comms['K']),
            nbands=nbands,
            nkpts=len(ibz),
            nelectrons=nelectrons,
            # Maps 1 -> 1, 2 -> 2 and 4 -> 1; presumably the
            # 4-component (non-collinear) case is treated as a single
            # spin channel -- TODO confirm.
            nspins=ncomponents % 3,
            fixed_magmom_value=magmom_v[2],
            rcell=rcell,
            monkhorst_pack_size=getattr(ibz.bz, 'size_c', None),
            bz2ibzmap=ibz.bz2ibz_K)
        self.extrapolate_factor = self.occ.extrapolate_factor

    def __str__(self):
        """Return the string form of the wrapped calculator."""
        return str(self.occ)

    def _set_nbands(self, nbands):
        """Replace the wrapped calculator with a copy using *nbands* bands.

        The k-point and domain communicators of the current parallel
        layout are kept; only the band descriptor is rebuilt.
        ``self.extrapolate_factor`` is not re-read here.
        """
        _, kpt_comm, domain_comm = self.occ.parallel_layout
        self.occ = self.occ.copy(
            parallel_layout=ParallelLayout(BandDescriptor(nbands),
                                           kpt_comm, domain_comm))

    def calculate(self,
                  nelectrons: float,
                  eigenvalues: ArrayLike2D,
                  weights: list[float],
                  fermi_levels_guess: list[float] | None = None,
                  fix_fermi_level: bool = False
                  ) -> tuple[Array2D, list[float], float]:
        """Compute occupation numbers by delegating to the wrapped calculator.

        All arguments are forwarded positionally to
        ``self.occ.calculate``, and its three results are returned as
        ``(occupations, fermi_levels, energy)``.
        """
        occs, fls, e = self.occ.calculate(nelectrons, eigenvalues, weights,
                                          fermi_levels_guess, fix_fermi_level)
        return occs, fls, e

    def initialize_reference_orbitals(self):
        """Call ``initialize_reference_orbitals`` on the wrapped calculator.

        Calculators that do not define the method are skipped.  Errors
        raised *inside* the method propagate to the caller.
        """
        initialize = getattr(self.occ, 'initialize_reference_orbitals', None)
        if initialize is not None:
            initialize()