Coverage for gpaw/new/basis.py: 100%
44 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-12 00:18 +0000
1import numpy as np
2from gpaw import GPAW_NO_C_EXTENSION
3from gpaw.core import PWDesc, UGDesc
4from gpaw.kpt_descriptor import KPointDescriptor
5from gpaw.lfc import BasisFunctions
6from gpaw.mpi import serial_comm
7from gpaw.new.brillouin import IBZ
def create_basis(ibz: IBZ,
                 nspins,
                 pbc_c,
                 grid,
                 setups,
                 dtype,
                 relpos_ac,
                 comm=serial_comm,
                 kpt_comm=serial_comm,
                 band_comm=serial_comm):
    """Create the atom-centred basis functions for ``grid`` and ``setups``.

    Parameters
    ----------
    ibz:
        Irreducible Brillouin zone.  Its k-points, weights and symmetry
        mappings are copied into an old-style ``KPointDescriptor``.
    nspins:
        Number of spins, passed on to ``KPointDescriptor``.
    pbc_c:
        Not used by this function.
    grid:
        Uniform grid.  ``SimpleBasis`` receives it directly;
        ``BasisFunctions`` receives ``grid._gd``.
    setups:
        Per-atom setups; each must provide ``basis_functions_J``.
    dtype:
        Wave-function dtype.  A complex floating dtype gives complex basis
        functions; anything else gives real (``float``) ones.
    relpos_ac:
        Atomic positions, handed to ``SimpleBasis`` or to
        ``BasisFunctions.set_positions``.
    comm:
        Not used by this function.
    kpt_comm:
        Communicator installed on the k-point descriptor.
    band_comm:
        Communicator over which the index range ``[0, basis.Mmax)`` is
        split into contiguous slices for ``set_matrix_distribution``.

    Returns
    -------
    A ``SimpleBasis`` when GPAW runs without its C extension.  Otherwise a
    positioned ``BasisFunctions`` object with its matrix distribution set.
    """
    if GPAW_NO_C_EXTENSION:
        # The pure-Python fallback never uses a k-point descriptor, so
        # return before spending time building and distributing one.
        return SimpleBasis(grid, setups, relpos_ac)

    kd = _kpoint_descriptor_from_ibz(ibz, nspins, kpt_comm)

    basis_dtype = (complex if np.issubdtype(dtype, np.complexfloating)
                   else float)
    basis = BasisFunctions(grid._gd,
                           [setup.basis_functions_J for setup in setups],
                           kd,
                           dtype=basis_dtype,
                           cut=True)
    basis.set_positions(relpos_ac)

    # Each band_comm rank gets a contiguous slice of ceil(Mmax / size)
    # indices.  Clamping both ends to Mmax leaves the trailing ranks with
    # a short (possibly empty) slice instead of running past the end.
    myM = (basis.Mmax + band_comm.size - 1) // band_comm.size
    basis.set_matrix_distribution(
        min(band_comm.rank * myM, basis.Mmax),
        min((band_comm.rank + 1) * myM, basis.Mmax))
    return basis


def _kpoint_descriptor_from_ibz(ibz: IBZ, nspins, kpt_comm):
    """Build an old-style KPointDescriptor that mirrors a new-style IBZ."""
    kd = KPointDescriptor(ibz.bz.kpt_Kc, nspins)
    # Overwrite the descriptor's own k-point bookkeeping with the values
    # already computed by the IBZ object.
    kd.ibzk_kc = ibz.kpt_kc
    kd.weight_k = ibz.weight_k
    kd.sym_k = ibz.s_K
    kd.time_reversal_k = ibz.time_reversal_K
    kd.bz2ibz_k = ibz.bz2ibz_K
    kd.ibz2bz_k = ibz.ibz2bz_k
    kd.bz2bz_ks = ibz.bz2bz_Ks
    kd.nibzkpts = len(ibz)
    kd.symmetry = ibz.symmetries._old_symmetry
    kd.set_communicator(kpt_comm)
    return kd
class SimpleBasis:
    """Pure-Python basis built from plane waves.

    ``create_basis`` returns this object instead of ``BasisFunctions`` when
    GPAW runs without its C extension.  The setups' atom-centred basis
    functions are expanded in a plane-wave description whose cutoff is the
    smaller of 12.5 and ``grid.ekin_max()``.
    """

    def __init__(self,
                 grid: UGDesc,
                 setups,
                 relpos_ac):
        self.grid = grid
        # Never go above 12.5, but stay within what the grid can represent.
        cutoff = min(12.5, grid.ekin_max())
        self.pw = PWDesc(cell=grid.cell, ecut=cutoff)
        functions_aJ = [setup.basis_functions_J for setup in setups]
        self.phit_aIG = self.pw.atom_centered_functions(functions_aJ,
                                                        relpos_ac)

    def add_to_density(self,
                       nt_sR: np.ndarray,
                       f_asi):
        """Accumulate occupation-weighted squared basis functions.

        For every column ``f_s`` of every ``f_si`` in ``f_asi.values()``
        (taken in that order), the corresponding basis function is
        transformed to ``self.grid``.  ``f_s * phit_R.data**2`` is then
        added in place to ``nt_sR``, with ``f_s`` broadcast along the
        leading axis of ``nt_sR``.

        NOTE(review): this squares ``phit_R.data`` directly, so it equals
        |phit|^2 only if those real-space values are real.  Confirm the
        dtype ``PWDesc`` picks here.
        """
        n_orbitals = sum(f_si.shape[1] for f_si in f_asi.values())

        # Identity coefficients: presumably add_to() then places basis
        # function I, and nothing else, in row I of phit_IG.
        coef_aiI = self.phit_aIG.empty(n_orbitals)
        coef_aiI.data[:] = np.eye(n_orbitals)
        phit_IG = self.pw.zeros(n_orbitals)
        self.phit_aIG.add_to(phit_IG, coef_aiI)

        occupations = (f_s for f_si in f_asi.values() for f_s in f_si.T)
        for index, f_s in enumerate(occupations):
            phit_R = phit_IG[index].ifft(grid=self.grid)
            weight_sR = f_s[:, None, None, None]
            nt_sR += weight_sR * phit_R.data**2