Coverage for gpaw/new/basis.py: 100%

44 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1import numpy as np 

2from gpaw import GPAW_NO_C_EXTENSION 

3from gpaw.core import PWDesc, UGDesc 

4from gpaw.kpt_descriptor import KPointDescriptor 

5from gpaw.lfc import BasisFunctions 

6from gpaw.mpi import serial_comm 

7from gpaw.new.brillouin import IBZ 

8 

9 

def _kpoint_descriptor_from_ibz(ibz: IBZ, nspins, kpt_comm):
    """Build a KPointDescriptor mirroring the k-point/symmetry data of *ibz*."""
    kd = KPointDescriptor(ibz.bz.kpt_Kc, nspins)
    # Overwrite the descriptor's own bookkeeping with the values already
    # computed by the IBZ object so that both agree exactly.
    kd.ibzk_kc = ibz.kpt_kc
    kd.weight_k = ibz.weight_k
    kd.sym_k = ibz.s_K
    kd.time_reversal_k = ibz.time_reversal_K
    kd.bz2ibz_k = ibz.bz2ibz_K
    kd.ibz2bz_k = ibz.ibz2bz_k
    kd.bz2bz_ks = ibz.bz2bz_Ks
    kd.nibzkpts = len(ibz)
    kd.symmetry = ibz.symmetries._old_symmetry
    kd.set_communicator(kpt_comm)
    return kd


def create_basis(ibz: IBZ,
                 nspins,
                 pbc_c,
                 grid,
                 setups,
                 dtype,
                 relpos_ac,
                 comm=serial_comm,
                 kpt_comm=serial_comm,
                 band_comm=serial_comm):
    """Create atom-centered basis functions for the given setups.

    Parameters
    ----------
    ibz:
        Irreducible Brillouin zone; its k-point and symmetry data are
        copied into the ``KPointDescriptor`` handed to ``BasisFunctions``.
    nspins:
        Number of spins, passed to ``KPointDescriptor``.
    pbc_c:
        Not used by this function (kept for interface compatibility).
    grid:
        Uniform grid; ``grid._gd`` is used by ``BasisFunctions`` and
        ``grid`` itself by the ``SimpleBasis`` fallback.
    setups:
        Per-atom setups; each provides ``basis_functions_J``.
    dtype:
        If a complex-floating dtype, the basis is complex, otherwise float.
    relpos_ac:
        Relative (scaled) atomic positions.
    comm:
        Not used by this function (kept for interface compatibility).
    kpt_comm:
        Communicator given to the k-point descriptor.
    band_comm:
        The basis-function indices ``0 .. Mmax`` are split into contiguous
        chunks of ``ceil(Mmax / band_comm.size)`` across its ranks.

    Returns
    -------
    A ``SimpleBasis`` when ``GPAW_NO_C_EXTENSION`` is set, otherwise a
    configured ``BasisFunctions`` object.
    """
    if GPAW_NO_C_EXTENSION:
        # The pure-Python fallback needs no k-point descriptor, so return
        # before building one.
        return SimpleBasis(grid, setups, relpos_ac)

    kd = _kpoint_descriptor_from_ibz(ibz, nspins, kpt_comm)
    basis_dtype = (complex if np.issubdtype(dtype, np.complexfloating)
                   else float)
    basis = BasisFunctions(grid._gd,
                           [setup.basis_functions_J for setup in setups],
                           kd,
                           dtype=basis_dtype,
                           cut=True)
    basis.set_positions(relpos_ac)

    # Ceiling division: every band_comm rank gets at most myM indices; the
    # bounds are clipped so trailing ranks may get fewer (or none).
    myM = (basis.Mmax + band_comm.size - 1) // band_comm.size
    basis.set_matrix_distribution(
        min(band_comm.rank * myM, basis.Mmax),
        min((band_comm.rank + 1) * myM, basis.Mmax))
    return basis

47 

48 

class SimpleBasis:
    """Plane-wave-based stand-in for ``BasisFunctions``.

    The module's ``create_basis`` returns this when ``GPAW_NO_C_EXTENSION``
    is set.  Basis functions are represented as atom-centered functions on a
    plane-wave descriptor whose cutoff is ``min(12.5, grid.ekin_max())``.
    """

    def __init__(self,
                 grid: UGDesc,
                 setups,
                 relpos_ac):
        """Set up the plane-wave representation of the basis functions.

        grid: uniform grid the densities live on (kept for inverse FFTs).
        setups: per-atom setups; each provides ``basis_functions_J``.
        relpos_ac: relative (scaled) atomic positions.
        """
        self.grid = grid
        cutoff = min(12.5, grid.ekin_max())
        self.pw = PWDesc(cell=grid.cell, ecut=cutoff)
        radial_functions = [setup.basis_functions_J for setup in setups]
        self.phit_aIG = self.pw.atom_centered_functions(radial_functions,
                                                        relpos_ac)

    def add_to_density(self,
                       nt_sR: np.ndarray,
                       f_asi):
        """Add occupation-weighted squared basis functions to ``nt_sR``.

        ``nt_sR`` is updated in place.  For each entry ``f_si`` of ``f_asi``
        (in iteration order), every column ``f_s`` weights the square of the
        corresponding basis function on ``self.grid``.  The number of
        columns summed over all entries gives the total number of basis
        functions.

        NOTE(review): assumes ``nt_sR`` is (nspins, *grid shape) and that the
        iteration order of ``f_asi`` matches the atom/function order of
        ``self.phit_aIG`` — confirm against callers.  ``.data**2`` equals
        ``|phi|^2`` only when the real-space data are real-valued.
        """
        nfuncs = sum(f_si.shape[1] for f_si in f_asi.values())

        # Identity coefficients: function I is placed alone in row I, so
        # each row of phit_IG holds exactly one basis function.
        coef_aiI = self.phit_aIG.empty(nfuncs)
        coef_aiI.data[:] = np.eye(nfuncs)
        phit_IG = self.pw.zeros(nfuncs)
        self.phit_aIG.add_to(phit_IG, coef_aiI)

        # Flatten the per-atom occupation columns into one running sequence.
        occupation_columns = (f_s
                              for f_si in f_asi.values()
                              for f_s in f_si.T)
        for func_index, f_s in enumerate(occupation_columns):
            squared_R = phit_IG[func_index].ifft(grid=self.grid).data**2
            weights = f_s[:, np.newaxis, np.newaxis, np.newaxis]
            nt_sR += weights * squared_R