Coverage for gpaw/new/pwfd/move_wfs.py: 91%

23 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from gpaw.core.arrays import DistributedArrays as XArray 

5from gpaw.core.atom_arrays import AtomArrays 

6from gpaw.setup import Setups 

7 

8 

def move_wave_functions(oldrelpos_ac: np.ndarray,
                        newrelpos_ac: np.ndarray,
                        P_ani: AtomArrays,
                        psit_nX: XArray,
                        setups: Setups) -> None:
    """Move wavefunctions with atoms according to PAW basis.

    Wavefunctions are approximated as::

        ~  _     --  ~a _   ~a  ~
        ψ (r) =   >  φ (r) <p | ψ >,
         n       --   i      i   n
                 ai

    where i runs over the bound partial-waves only.
    This quantity is subtracted at the old positions and re-added at
    the new positions.  ``psit_nX`` is modified in place.

    Parameters
    ----------
    oldrelpos_ac:
        Atomic positions the wave functions currently correspond to.
        Presumably these are relative (scaled) coordinates, since their
        rounded difference is used as a lattice displacement.
    newrelpos_ac:
        Atomic positions to move the partial-wave expansion to.
    P_ani:
        Projections <p|ψ> at the old positions.  Only the leading
        columns, one per bound partial wave, are used.
    psit_nX:
        Pseudo wave functions.  They are updated in place.
    setups:
        Per-atom setups supplying the partial waves through
        ``get_partial_waves_for_atomic_orbitals()``.
    """
    desc = psit_nX.desc
    atomdist = P_ani.layout.atomdist

    # Create partial-wave ACF object (b denotes bound states):
    phit_abX = desc.atom_centered_functions(
        [setup.get_partial_waves_for_atomic_orbitals() for setup in setups],
        oldrelpos_ac,
        atomdist=atomdist,
        cut=True,
        xp=psit_nX.xp)

    P_anb = phit_abX.empty(psit_nX.dims, comm=psit_nX.comm)
    for a, P_nb in P_anb.items():
        # Keep only the bound-state projections (assumed to be the first
        # columns of P_ani[a]).  Negate them so add_to() subtracts.
        P_nb[:] = -P_ani[a][:, :P_nb.shape[1]]

    # Subtract partial wave expansion at the old positions:
    phit_abX.add_to(psit_nX, P_anb)

    # Flip the sign back so the expansion is added at the new positions.
    # For complex (k-point) wave functions, an atom whose position jumps by
    # a whole lattice vector also needs the phase exp(2πi R·k).  The check
    # accepts any complex precision, not only complex128.
    if np.issubdtype(desc.dtype, np.complexfloating):
        disp_ac = (newrelpos_ac - oldrelpos_ac).round()
        phase_a = np.exp(2j * np.pi * disp_ac @ desc.kpt_c)
        for a, P_nb in P_anb.items():
            P_nb *= -phase_a[a]
    else:
        P_anb.data *= -1.0

    # Add partial wave expansion at new positions.  If moving the
    # functions changes the atom distribution, redistribute the
    # coefficients to match.
    atomdist2 = phit_abX.move(newrelpos_ac, atomdist)
    if atomdist2 is not atomdist:
        P_anb = P_anb.moved(atomdist2)
    phit_abX.add_to(psit_nX, P_anb)