Coverage for gpaw/hybrids/paw.py: 85%

54 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2from typing import TYPE_CHECKING 

3from typing import NamedTuple, Dict, List 

4 

5import numpy as np 

6 

7from gpaw.mpi import broadcast 

8from gpaw.utilities import (pack_atomic_matrices, unpack_atomic_matrices, 

9 unpack_density, unpack_hermitian, packed_index) 

10 

11 

class PAWThings(NamedTuple):
    """Bundle of PAW correction data for one spin channel.

    ``calculate_paw_stuff`` returns one instance per spin channel; the
    ``VC_aii`` dict and ``Delta_aiiL`` list it stores are shared between
    those instances, while ``VV_aii`` is specific to each spin.
    """

    # Atom index -> ``unpack_hermitian(setup.X_p)``, or None for setups
    # whose ``X_p`` is None.
    VC_aii: Dict[int, np.ndarray | None]
    # Atom index -> valence-valence correction matrix.  Distributed: only
    # atoms handled on this rank are present.
    VV_aii: Dict[int, np.ndarray]  # distributed
    # ``setup.Delta_iiL`` for every atom, in atom order.
    Delta_aiiL: List[np.ndarray]

16 

17 

def calculate_paw_stuff(wfs, dens) -> List[PAWThings]:
    """Assemble the PAW correction data for every spin channel.

    Parameters
    ----------
    wfs:
        Wave-function object; ``wfs.world`` and ``wfs.setups`` are read.
    dens:
        Density object; ``dens.D_asp`` and ``dens.nspins`` are read.

    Returns
    -------
    list of PAWThings
        One entry per spin channel (``dens.nspins`` entries).  Every
        entry references the same ``VC_aii`` dict and ``Delta_aiiL`` list;
        ``VV_aii`` is per spin and only contains atoms handled on this
        rank.
    """
    D_asp = dens.D_asp
    comm = D_asp.partition.comm
    world = wfs.world

    if comm.size != world.size:
        # The density matrices are distributed over a communicator that
        # differs in size from ``world``.  Pack them, broadcast rank 0's
        # packed copy over ``comm``, unpack, and then keep only the atoms
        # assigned to this world rank (contiguous blocks of atoms per rank).
        packed_sP = pack_atomic_matrices(D_asp)
        packed_sP = broadcast(packed_sP if comm.rank == 0 else None,
                              comm=comm)
        unpacked_asp = unpack_atomic_matrices(packed_sP, wfs.setups)
        owner_a = np.linspace(0, world.size, len(wfs.setups),
                              endpoint=False).astype(int)
        D_asp = {a: D_sp for a, D_sp in unpacked_asp.items()
                 if owner_a[a] == world.rank}

    # Scale applied to each unpacked density matrix: 1/2 when nspins == 1,
    # 1 when nspins == 2.
    spin_factor = dens.nspins / 2
    VV_saii: List[Dict[int, np.ndarray]] = [{} for _ in range(dens.nspins)]
    for a, D_sp in D_asp.items():
        atom_setup = wfs.setups[a]
        # zip() stops at the shorter of VV_saii and D_sp.
        for VV_a, D_p in zip(VV_saii, D_sp):
            VV_a[a] = pawexxvv(atom_setup.M_pp,
                               unpack_density(D_p) * spin_factor)

    Delta_aiiL = [setup.Delta_iiL for setup in wfs.setups]
    VC_aii: Dict[int, np.ndarray | None] = {
        a: None if setup.X_p is None else unpack_hermitian(setup.X_p)
        for a, setup in enumerate(wfs.setups)}

    return [PAWThings(VC_aii, VV_a, Delta_aiiL) for VV_a in VV_saii]

49 

50 

def python_pawexxvv(M_pp, D_ii):
    """PAW correction for valence-valence EXX (pure-Python version).

    Computes the ``(ni, ni)`` float64 matrix::

        V_ii[i1, i2] = sum_{i3, i4} M_pp[p(i1, i3), p(i2, i4)] * D_ii[i3, i4]

    where ``ni = len(D_ii)`` and ``p(i, j) = packed_index(i, j, ni)`` maps
    a pair of indices onto the packed dimension that ``M_pp`` is indexed
    by.
    """
    ni = len(D_ii)
    # Packed index of every (i, j) pair, computed once up front.
    p_ii = [[packed_index(i, j, ni) for j in range(ni)] for i in range(ni)]
    V_ii = np.empty((ni, ni))
    for i1, p1_j in enumerate(p_ii):
        for i2, p2_j in enumerate(p_ii):
            acc = 0.0
            for i3 in range(ni):
                p13 = p1_j[i3]
                for i4 in range(ni):
                    acc += M_pp[p13, p2_j[i4]] * D_ii[i3, i4]
            V_ii[i1, i2] = acc
    return V_ii

65 

66 

# Start from the pure-Python implementation; at runtime it is replaced by
# the compiled version from the ``_gpaw`` extension when that is importable.
pawexxvv = python_pawexxvv


# Type checkers (TYPE_CHECKING is True for them) only see the Python
# definition above.
if not TYPE_CHECKING:
    try:
        from _gpaw import pawexxvv  # noqa: F811
    except ImportError:
        # Extension missing or without pawexxvv: keep the Python fallback.
        import warnings
        warnings.warn(
            'Please recompile GPAW binary. Using python '
            'version of pawexxvv instead of faster c version.')