Coverage for gpaw/basis_data.py: 95%

299 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass, field, replace 

4import xml.sax 

5 

6import numpy as np 

7 

8from gpaw.setup_data import search_for_file 

9from gpaw.atom.radialgd import RadialGridDescriptor, radial_grid_descriptor 

10 

11 

12_basis_letter2number = {'s': 1, 'd': 2, 't': 3, 'q': 4} 

13_basis_number2letter = 'Xsdtq56789' 

14 

15 

16def parse_basis_name(name): 

17 """Parse any basis type identifier: 'sz', 'dzp', 'qztp', '4z3p', ... """ 

18 

19 zetacount = _basis_letter2number.get(name[0]) 

20 if zetacount is None: 

21 zetacount = int(name[0]) 

22 assert name[1] == 'z' 

23 

24 if len(name) == 2: 

25 polcount = 0 

26 elif len(name) == 3: 

27 assert name[-1] == 'p' 

28 polcount = 1 

29 else: 

30 assert len(name) == 4 and name[-1] == 'p' 

31 polcount = _basis_letter2number.get(name[2]) 

32 if polcount is None: 

33 polcount = int(name[2]) 

34 

35 return zetacount, polcount 

36 

37 

def parse_basis_filename(filename: str) -> tuple[str, str | None]:
    """Split a basis file name into ``(symbol, name)``.

    The expected layout is ``<symbol>[.<name>].basis[.gz]``.  The name may
    itself contain dots; it is returned as None when the file name has no
    name part (e.g. ``'H.basis'``).

    Raises RuntimeError if the file name does not end in ``.basis`` or
    ``.basis.gz``.
    """
    tokens = filename.split('.')
    if tokens[-1] == 'gz':
        tokens.pop()

    # ``tokens`` is empty when the whole file name was just 'gz'.
    if not tokens or tokens[-1] != 'basis':
        raise RuntimeError('Expected <symbol>[.<name>].basis[.gz], '
                           f'got {filename!r}')

    symbol = tokens[0]
    name = '.'.join(tokens[1:-1])
    return symbol, name or None

52 

53 

def get_basis_name(zetacount, polarizationcount):
    """Build a basis type identifier such as 'sz', 'dzp' or 'qztp'.

    Both counts are looked up by index in ``_basis_number2letter``.  The
    polarization count is omitted from the name when it is 0 and written
    as a bare 'p' when it is 1.
    """
    prefix = f'{_basis_number2letter[zetacount]}z'
    if polarizationcount == 0:
        return prefix
    if polarizationcount == 1:
        return prefix + 'p'
    return f'{prefix}{_basis_number2letter[polarizationcount]}p'

63 

64 

@dataclass(eq=False, frozen=True)
class Basis:
    """LCAO basis set for a single element.

    Holds the radial grid descriptor, the list of basis functions, the
    list of auxiliary (RI) basis functions and the generator metadata
    that is written to / read from ``<paw_basis>`` XML files.
    """

    symbol: str
    name: str | None = None
    rgd: RadialGridDescriptor | None = None

    # Ordinary basis functions and auxiliary (RI) basis functions.
    bf_j: list[BasisFunction] = field(default_factory=list)
    ribf_j: list[BasisFunction] = field(default_factory=list)
    # Attributes and text content of the <generator> XML element.
    generatorattrs: dict = field(default_factory=dict)
    generatordata: str = ''
    filename: str | None = None

    @classmethod
    def find(cls, symbol, name, world=None):
        """Read the basis set without an explicit file name.

        Delegates to read_xml() with no filename, in which case the
        parser searches for the file itself.
        """
        return cls.read_xml(symbol, name, world=world)

    @classmethod
    def read_path(cls, symbol, name, path, world=None):
        """Read the basis set from the file at *path*."""
        return cls.read_xml(symbol, name, filename=path, world=world)

    @classmethod
    def read_xml(cls, symbol, name, filename=None, world=None):
        """Parse a basis set XML file and return the resulting Basis."""
        parser = BasisSetXMLParser(symbol, name)
        return parser.parse(filename, world=world)

    @property
    def nao(self):
        """Sum of 2l + 1 over the basis functions."""
        # Implemented as a property so we don't have to catch all the
        # places where Basis objects are modified without updating it.
        # (we can do that later)
        return sum(2 * bf.l + 1 for bf in self.bf_j)

    @property
    def nrio(self):
        """Sum of 2l + 1 over the auxiliary (RI) basis functions."""
        return sum(2 * ribf.l + 1 for ribf in self.ribf_j)

    def get_grid_descriptor(self):
        """Return the radial grid descriptor (``self.rgd``)."""
        return self.rgd

    def tosplines(self):
        """Return one ``rgd.spline(...)`` (400 points) per basis function."""
        return [self.rgd.spline(bf.phit_g, bf.rc, bf.l, points=400)
                for bf in self.bf_j]

    def ritosplines(self):
        """Return one ``rgd.spline(...)`` (400 points) per RI function."""
        return [self.rgd.spline(ribf.phit_g, ribf.rc, ribf.l, points=400)
                for ribf in self.ribf_j]

    def write_xml(self):
        """Write basis functions to file.

        Writes all basis functions in the given list of basis functions
        to the file "<symbol>.<name>.basis", or to "<symbol>.basis" when
        the basis set has no name.
        """
        if self.name is None:
            filename = f'{self.symbol}.basis'
        else:
            filename = f'{self.symbol}.{self.name}.basis'

        with open(filename, 'w') as fd:
            self.write_to(fd)

    def write_to(self, fd):
        """Write the basis set as ``<paw_basis>`` XML to the stream *fd*.

        Generator attribute values and generator text are XML-escaped so
        that characters such as ``&``, ``<`` and ``"`` still give a
        well-formed document; text without such characters is written
        unchanged.
        """
        from xml.sax.saxutils import escape, quoteattr

        write = fd.write
        write('<paw_basis version="0.1">\n')

        generatorattrs = ' '.join(
            f'{key}={quoteattr(str(value))}'
            for key, value in self.generatorattrs.items())
        write(' <generator %s>' % generatorattrs)
        for line in self.generatordata.split('\n'):
            write('\n ' + escape(line))
        write('\n </generator>\n')

        write(' ' + self.rgd.xml())

        # Write both the basis functions and auxiliary ones
        for bf in [*self.bf_j, *self.ribf_j]:
            write(bf.xml(indentation=' '))

        write('</paw_basis>\n')

    def reduce(self, name):
        """Reduce the number of basis functions and return new Basis.

        Example: basis.reduce('sz') will remove all non single-zeta
        and polarization functions."""

        zeta, pol = parse_basis_name(name)
        newbf_j = []
        # Number of functions kept so far for each (n, l) pair.
        N = {}
        # Number of polarization functions kept so far.
        p = 0
        for bf in self.bf_j:
            if 'polarization' in bf.type:
                if p < pol:
                    newbf_j.append(bf)
                    p += 1
                continue

            # The type string starts with the n digit and the l letter.
            nl = (int(bf.type[0]), 'spdf'.index(bf.type[1]))
            kept = N.get(nl, 0)
            if kept < zeta:
                newbf_j.append(bf)
                N[nl] = kept + 1
        return replace(self, bf_j=newbf_j)

    def get_description(self):
        """Return a multi-line, human-readable summary of the basis set."""
        if self.name is None:
            name = 'This basis set does not have a name'
        else:
            name = f'Name: {self.name}'

        if self.filename is None:
            fileinfo = 'This basis set was not loaded from a file'
        else:
            # Formatted rather than concatenated: the filename may be a
            # path object instead of a plain string.
            fileinfo = f'File: {self.filename}'

        lines = [
            f'LCAO basis set for {self.symbol}:',
            name,
            fileinfo,
            'Number of radial functions: %d' % len(self.bf_j),
            'Number of spherical harmonics: %d' % self.nao,
        ]
        lines.extend(' l=%d, rc=%.4f Bohr: %s' % (bf.l, bf.rc, bf.type)
                     for bf in self.bf_j)
        lines.append(f'Number of RI-basis functions {self.nrio}')
        lines.extend('l=%d %s' % (ribf.l, ribf.type)
                     for ribf in self.ribf_j)

        return '\n '.join(lines)

196 

197 

@dataclass
class BasisFunction:
    """Encapsulates various basis function data."""

    n: int | None = None
    l: int | None = None
    rc: float | None = None
    phit_g: np.ndarray | None = None
    type: str | None = None

    @property
    def name(self):
        """Short label: ``'<n><l-letter> <type>'`` or ``'l=<l> <type>'``.

        The second form is used when ``n`` is None or negative.
        """
        if self.n is None or self.n < 0:
            return f'l={self.l} {self.type}'

        lname = 'spdf'[self.l]
        # Must be self.type: a bare ``type`` inside a method refers to
        # the builtin, not to the dataclass field.
        return f'{self.n}{lname} {self.type}'

    def __repr__(self, gridid=None):
        """Return the opening ``<basis_function ...>`` XML tag.

        The ``n`` attribute is written only when ``n`` is not None and
        the ``grid`` attribute only when *gridid* is given.
        """
        txt = '<basis_function '
        if self.n is not None:
            txt += 'n="%d" ' % self.n
        txt += f'l="{self.l}" rc="{self.rc}" type="{self.type}"'
        if gridid is not None:
            txt += ' grid="%s"' % gridid
        return txt + '>'

    def xml(self, gridid='grid1', indentation=''):
        """Return the complete ``<basis_function>`` XML element as text.

        The values of ``phit_g`` are written space-separated on one
        line; every line is prefixed with *indentation*.
        """
        values = ' '.join(str(x) for x in self.phit_g)
        return (f'{indentation}{self.__repr__(gridid)}\n'
                f'{indentation} {values}\n'
                f'{indentation}</basis_function>\n')

230 

231 

class BasisSetXMLParser(xml.sax.handler.ContentHandler):
    """SAX content handler that builds a Basis from basis-set XML."""

    def __init__(self, symbol, name):
        super().__init__()
        self.symbol = symbol
        self.name = name

        # State of the <basis_function> element currently being read.
        self.type = None
        self.rc = None
        self.l = None
        self.n = None
        # Character chunks of the element being read; None means that
        # text is currently not being collected.
        self.data = None

        self.filename = None
        self.bf_j = []
        self.ribf_j = []

        # Keyword arguments (generatorattrs, generatordata, rgd)
        # collected for the Basis constructor.
        self._dct = {}

    def parse(self, filename=None, world=None):
        """Read from symbol.name.basis file.

        Example of filename: N.dzp.basis. Use sz(dzp) to read
        the sz-part from the N.dzp.basis file.

        If *filename* is None the file "<symbol>.<name>.basis" (or
        "<symbol>.basis" when the name is None) is looked up with
        search_for_file(); otherwise *filename* is read directly.

        Raises ValueError if the name contains '(' but does not have
        the form "<reduced>(<name>)".
        """
        from gpaw.setup_data import read_maybe_unzipping

        name = self.name
        reduced = None
        if name is not None and '(' in name:
            reduced, _, rest = name.partition('(')
            if '(' in rest or not rest.endswith(')'):
                raise ValueError(
                    f'Bad basis name {self.name!r}: '
                    'expected <reduced>(<name>)')
            name = rest[:-1]

        if filename is None:
            # Unnamed basis sets are stored without the name part.
            if name is None:
                fullname = f'{self.symbol}.basis'
            else:
                fullname = f'{self.symbol}.{name}.basis'
            filename, source = search_for_file(fullname, world=world)
        else:
            source = read_maybe_unzipping(filename)

        self.filename = filename
        self.data = None
        xml.sax.parseString(source, self)

        basis = Basis(symbol=self.symbol, name=self.name, filename=filename,
                      bf_j=[*self.bf_j], ribf_j=[*self.ribf_j],
                      **self._dct)

        if reduced:
            basis = basis.reduce(reduced)

        return basis

    def startElement(self, name, attrs):
        """Record the attributes of the element that is being opened."""
        dct = self._dct
        # For name == 'paw_basis' we can save attrs['version'], too.
        if name == 'generator':
            dct['generatorattrs'] = dict(attrs)
            self.data = []
        elif name == 'radial_grid':
            dct['rgd'] = radial_grid_descriptor(**attrs)
        elif name == 'basis_function':
            self.l = int(attrs['l'])
            self.rc = float(attrs['rc'])
            self.type = attrs.get('type')
            self.data = []
            if 'n' in attrs:
                self.n = int(attrs['n'])
            elif self.type and self.type[0].isdigit():
                # No explicit n: take it from the leading digit of the
                # type string.  The truth test also covers a missing or
                # empty type attribute.
                self.n = int(self.type[0])
            else:
                self.n = None

    def characters(self, data):
        """Collect text while inside an element whose text is wanted."""
        if self.data is not None:
            self.data.append(data)

    def endElement(self, name):
        """Turn the collected text into a BasisFunction or generator data."""
        if name == 'basis_function':
            phit_g = np.array([float(x) for x in ''.join(self.data).split()])
            bf = BasisFunction(self.n, self.l, self.rc, phit_g, self.type)
            # Also auxiliary basis functions are added here.  They are
            # distinguished by their type='auxiliary'.
            if bf.type == 'auxiliary':
                self.ribf_j.append(bf)
            else:
                self.bf_j.append(bf)
            # Stop collecting so text between elements is not kept.
            self.data = None

        elif name == 'generator':
            self._dct['generatordata'] = ''.join(self.data)
            self.data = None

318 

319 

class BasisPlotter:
    """Plot the radial functions of a basis set with matplotlib."""

    def __init__(self, premultiply=True, normalize=False,
                 show=False, save=False, ext='png'):
        """Store the plot options.

        premultiply: plot r * phi(r) instead of phi(r).
        normalize: divide each curve by its printed norm.
        show: call pyplot.show() at the end of plot().
        save: save the figure at the end of plot().
        ext: extension used for the default file name.
        """
        self.premultiply = premultiply
        self.show = show
        self.save = save
        self.ext = ext
        # Templates filled in from the attributes of the plotted basis.
        self.default_filename = '%(symbol)s.%(name)s.' + ext

        self.title = 'Basis functions: %(symbol)s %(name)s'
        self.xlabel = 'radius [Bohr]'
        ylabel = r'\Phi(r)'
        if premultiply:
            ylabel = 'r' + ylabel
        self.ylabel = f'${ylabel}$'

        self.normalize = normalize

    def plot(self, basis, filename=None, ax=None, **plot_args):
        """Print a summary of *basis* and plot its basis functions.

        basis: object providing symbol, name, rgd, bf_j, generatorattrs
            and generatordata.
        filename: file name template used when saving; defaults to
            "<symbol>.<name>.<ext>".
        ax: axes to draw on; a new figure is created when None.
        plot_args: extra keyword arguments passed to every ax.plot().

        Returns the axes that were drawn on.
        """
        if ax is None:
            from matplotlib import pyplot as plt

            ax = plt.figure().gca()

        r_g = basis.rgd.r_g

        print('Element :', basis.symbol)
        print('Name :', basis.name)
        print()
        print('Basis functions')
        print('---------------')

        norm_j = []
        for bf in basis.bf_j:
            ng = len(bf.phit_g)
            rphit_g = r_g[:ng] * bf.phit_g
            norm = (rphit_g**2 * basis.rgd.dr_g[:ng]).sum()
            norm_j.append(norm)
            print(bf.type, '[norm=%0.4f]' % norm)

        print()
        print('Generator')
        for key, item in basis.generatorattrs.items():
            print(' ', key, ':', item)
        print()
        print('Generator data')
        print(basis.generatordata)

        if self.premultiply:
            factor = r_g
        else:
            factor = np.ones_like(r_g)

        # One dash pattern per l; patterns repeat for l beyond the list.
        dashes_l = [(), (6, 3), (4, 1, 1, 1), (1, 1)]

        for norm, bf in zip(norm_j, basis.bf_j):
            ng = len(bf.phit_g)
            y_g = bf.phit_g * factor[:ng]
            if self.normalize:
                # NOTE(review): this divides by the summed square itself,
                # not by its square root - confirm that is intended.
                y_g /= norm
            ax.plot(r_g[:ng], y_g, label=bf.type[:12],
                    dashes=dashes_l[bf.l % len(dashes_l)], lw=2,
                    **plot_args)

        if basis.bf_j:
            # Let the x axis run from 0 to the largest cutoff radius.
            _, _, ymin, ymax = ax.axis()
            rc = max(bf.rc for bf in basis.bf_j)
            ax.axis([0., rc, ymin, ymax])
        ax.legend()
        ax.set_title(self.title % basis.__dict__)
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)

        if filename is None:
            filename = self.default_filename
        if self.save:
            ax.get_figure().savefig(filename % basis.__dict__)

        if self.show:
            # Imported here as well: when the caller supplied *ax*, the
            # import at the top of this method did not run.
            from matplotlib import pyplot as plt

            plt.show()

        return ax

403 

404 

class CLICommand:
    """Plot basis set from FILE."""

    @staticmethod
    def add_arguments(parser):
        # Positional basis file plus an optional output file for the plot.
        write_help = 'write plot to file inferring format from file extension.'
        parser.add_argument('file', metavar='FILE')
        parser.add_argument('--write', metavar='FILE', help=write_help)

    @staticmethod
    def run(args):
        from pathlib import Path
        import matplotlib.pyplot as plt

        basis_path = Path(args.file)

        # The symbol and name are not stored inside the file itself, so
        # (not very elegantly) they are recovered from the file name.
        symbol, name = parse_basis_filename(basis_path.name)
        basis = Basis.read_path(symbol, name, path=basis_path)

        axes = BasisPlotter().plot(basis)

        # Without --write the plot is shown interactively instead.
        if not args.write:
            plt.show()
            return
        axes.get_figure().savefig(args.write)