Coverage for gpaw/basis_data.py: 95%
299 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1from __future__ import annotations
3from dataclasses import dataclass, field, replace
4import xml.sax
6import numpy as np
8from gpaw.setup_data import search_for_file
9from gpaw.atom.radialgd import RadialGridDescriptor, radial_grid_descriptor
12_basis_letter2number = {'s': 1, 'd': 2, 't': 3, 'q': 4}
13_basis_number2letter = 'Xsdtq56789'
def parse_basis_name(name):
    """Translate a basis type identifier into ``(zetacount, polcount)``.

    Accepted spellings include 'sz', 'dzp', 'qztp' and '4z3p': the first
    character gives the zeta multiplicity, the second must be 'z', and an
    optional trailing 'p' (preceded by a multiplicity character in the
    four-character form) gives the polarization count.

    Multiplicity characters are either one of the letters s/d/t/q or a
    literal digit.  Malformed names fail an ``assert`` or raise from
    ``int()``/indexing.
    """
    def _multiplicity(char):
        # Known letters go through the lookup table; anything else must be
        # a digit.
        number = _basis_letter2number.get(char)
        return int(char) if number is None else number

    zetacount = _multiplicity(name[0])
    assert name[1] == 'z'

    length = len(name)
    if length == 2:
        return zetacount, 0
    if length == 3:
        assert name[-1] == 'p'
        return zetacount, 1
    assert length == 4 and name[-1] == 'p'
    return zetacount, _multiplicity(name[2])
def parse_basis_filename(filename: str):
    """Split a basis-set file name into ``(symbol, name)``.

    The expected form is ``<symbol>[.<name>].basis[.gz]``.  The name may
    itself contain dots (everything between the symbol and '.basis' is
    kept).  When no name is present, ``None`` is returned in its place.

    Raises:
        RuntimeError: if *filename* does not end in '.basis' (optionally
            followed by '.gz').
    """
    tokens = filename.split('.')
    if tokens[-1] == 'gz':
        tokens.pop()

    # ``not tokens`` guards a bare 'gz', which would otherwise raise
    # IndexError instead of the documented RuntimeError.
    if not tokens or tokens[-1] != 'basis':
        # The f-prefix was missing here, so the message showed a literal
        # '{filename!r}' instead of the offending file name.
        raise RuntimeError('Expected <symbol>[.<name>].basis[.gz], '
                           f'got {filename!r}')

    symbol = tokens[0]
    name = '.'.join(tokens[1:-1])
    return symbol, name or None
def get_basis_name(zetacount, polarizationcount):
    """Build the short basis type identifier, e.g. ``(2, 1) -> 'dzp'``.

    No polarization gives '<z>z', a single polarization function gives
    '<z>zp', and anything else spells the polarization count explicitly,
    as in '<z>z<p>p'.  Counts are spelled via ``_basis_number2letter``.
    """
    zetachar = _basis_number2letter[zetacount]
    if polarizationcount == 0:
        suffix = ''
    elif polarizationcount == 1:
        suffix = 'p'
    else:
        suffix = _basis_number2letter[polarizationcount] + 'p'
    return zetachar + 'z' + suffix
@dataclass(eq=False, frozen=True)
class Basis:
    """LCAO basis set for one chemical element.

    Instances are frozen; derived basis sets (see :meth:`reduce`) are
    created with :func:`dataclasses.replace`.  ``bf_j`` holds the regular
    basis functions and ``ribf_j`` the auxiliary (RI) basis functions.
    """

    symbol: str
    name: str | None = None
    rgd: RadialGridDescriptor | None = None

    bf_j: list[BasisFunction] = field(default_factory=list)
    ribf_j: list[BasisFunction] = field(default_factory=list)
    generatorattrs: dict = field(default_factory=dict)
    generatordata: str = ''
    filename: str | None = None

    @classmethod
    def find(cls, symbol, name, world=None):
        """Read the basis set for *symbol*/*name*, locating the file by search."""
        return cls.read_xml(symbol, name, world=world)

    @classmethod
    def read_path(cls, symbol, name, path, world=None):
        """Read the basis set stored in the file at *path*."""
        return cls.read_xml(symbol, name, filename=path, world=world)

    @classmethod
    def read_xml(cls, symbol, name, filename=None, world=None):
        """Parse a basis set XML file; ``filename=None`` means search for it."""
        parser = BasisSetXMLParser(symbol, name)
        return parser.parse(filename, world=world)

    @property
    def nao(self):
        """Number of atomic orbitals: sum of 2l+1 over ``bf_j``."""
        # Implemented as a property so we don't have to catch all the places
        # where the bf_j list is modified without updating a stored count.
        return sum(2 * bf.l + 1 for bf in self.bf_j)

    @property
    def nrio(self):
        """Number of auxiliary (RI) orbitals: sum of 2l+1 over ``ribf_j``."""
        return sum(2 * ribf.l + 1 for ribf in self.ribf_j)

    def get_grid_descriptor(self):
        """Return the radial grid descriptor the functions are tabulated on."""
        return self.rgd

    def tosplines(self):
        """Return one ``rgd.spline`` (400 points) per regular basis function."""
        return [self.rgd.spline(bf.phit_g, bf.rc, bf.l, points=400)
                for bf in self.bf_j]

    def ritosplines(self):
        """Return one ``rgd.spline`` (400 points) per auxiliary function."""
        return [self.rgd.spline(ribf.phit_g, ribf.rc, ribf.l, points=400)
                for ribf in self.ribf_j]

    def write_xml(self):
        """Write all basis functions to "<symbol>[.<name>].basis".

        The file is created in the current directory and encoded as UTF-8.
        The document carries no XML encoding declaration, so UTF-8 is what
        an XML reader will assume when reading it back.
        """
        if self.name is None:
            filename = '%s.basis' % self.symbol
        else:
            filename = f'{self.symbol}.{self.name}.basis'

        with open(filename, 'w', encoding='utf-8') as fd:
            self.write_to(fd)

    def write_to(self, fd):
        """Serialize the basis set as ``<paw_basis>`` XML to text stream *fd*.

        Generator attribute values and generator data are XML-escaped, so
        characters such as ``&`` and ``<`` survive a write/read round trip.
        """
        from xml.sax.saxutils import escape, quoteattr

        write = fd.write
        write('<paw_basis version="0.1">\n')

        generatorattrs = ' '.join([f'{key}={quoteattr(str(value))}'
                                   for key, value
                                   in self.generatorattrs.items()])
        write(' <generator %s>' % generatorattrs)
        for line in self.generatordata.split('\n'):
            write('\n ' + escape(line))
        write('\n </generator>\n')

        write(' ' + self.rgd.xml())

        # Write both the basis functions and the auxiliary ones; the reader
        # separates them again by their type="auxiliary" attribute.
        for bfs in [self.bf_j, self.ribf_j]:
            for bf in bfs:
                write(bf.xml(indentation=' '))

        write('</paw_basis>\n')

    def reduce(self, name):
        """Reduce the number of basis functions and return a new Basis.

        *name* is a basis type such as 'sz' or 'dzp' (see parse_basis_name).
        For each (n, l) channel at most ``zeta`` non-polarization functions
        are kept, together with the first ``pol`` functions whose type
        contains 'polarization'.  The original order is preserved.

        Example: basis.reduce('sz') will remove all non single-zeta
        and polarization functions.
        """
        zeta, pol = parse_basis_name(name)
        newbf_j = []
        N = {}
        p = 0
        for bf in self.bf_j:
            if 'polarization' in bf.type:
                if p < pol:
                    newbf_j.append(bf)
                    p += 1
            else:
                # The type starts with e.g. '2s': a digit n, then an l letter.
                nl = (int(bf.type[0]), 'spdf'.index(bf.type[1]))
                if nl not in N:
                    N[nl] = 0
                if N[nl] < zeta:
                    newbf_j.append(bf)
                    N[nl] += 1
        return replace(self, bf_j=newbf_j)

    def get_description(self):
        """Return a multi-line, human-readable summary of the basis set."""
        title = 'LCAO basis set for %s:' % self.symbol
        if self.name is not None:
            name = 'Name: ' + self.name
        else:
            name = 'This basis set does not have a name'
        if self.filename is None:
            fileinfo = 'This basis set was not loaded from a file'
        else:
            fileinfo = 'File: ' + self.filename
        nj = len(self.bf_j)
        count1 = 'Number of radial functions: %d' % nj
        count2 = 'Number of spherical harmonics: %d' % self.nao

        bf_lines = []
        for bf in self.bf_j:
            line = ' l=%d, rc=%.4f Bohr: %s' % (bf.l, bf.rc, bf.type)
            bf_lines.append(line)

        lines = [title, name, fileinfo, count1, count2]
        lines.extend(bf_lines)
        lines.append(f'Number of RI-basis functions {self.nrio}')
        for ribf in self.ribf_j:
            lines.append('l=%d %s' % (ribf.l, ribf.type))

        return '\n '.join(lines)
@dataclass
class BasisFunction:
    """Encapsulates various basis function data.

    Fields:
        n: integer prefix of the label (the parser takes it from an 'n'
            attribute or the leading digit of ``type``), or None.
        l: angular momentum quantum number.
        rc: cutoff radius (reported in Bohr by Basis.get_description).
        phit_g: radial values tabulated on the basis set's radial grid.
        type: free-form label; the XML parser files functions whose type is
            'auxiliary' as RI functions.
    """

    n: int | None = None
    l: int | None = None
    rc: float | None = None
    phit_g: np.ndarray | None = None
    type: str | None = None

    @property
    def name(self):
        """Short label such as '2p <type>', or 'l=<l> <type>' without n."""
        if self.n is None or self.n < 0:
            return f'l={self.l} {self.type}'

        lname = 'spdf'[self.l]
        # Must be self.type: inside a method, bare ``type`` resolves to the
        # builtin, which previously produced "2p <class 'type'>".
        return f'{self.n}{lname} {self.type}'

    def __repr__(self, gridid=None):
        """Return the opening ``<basis_function ...>`` XML tag.

        The ``grid`` attribute is included only when *gridid* is given.
        """
        txt = '<basis_function '
        if self.n is not None:
            txt += 'n="%d" ' % self.n
        txt += (f'l="{self.l}" rc="{self.rc}" type="{self.type}"')
        if gridid is not None:
            txt += ' grid="%s"' % gridid
        return txt + '>'

    def xml(self, gridid='grid1', indentation=''):
        """Return the full ``<basis_function>`` element, values space-separated."""
        txt = indentation + self.__repr__(gridid) + '\n'
        txt += indentation + ' ' + ' '.join(str(x) for x in self.phit_g)
        txt += '\n' + indentation + '</basis_function>\n'
        return txt
class BasisSetXMLParser(xml.sax.handler.ContentHandler):
    """SAX handler that reads a ``<paw_basis>`` XML document into a Basis.

    *name* may be None (file '<symbol>.basis'), a plain basis name such as
    'dzp' (file '<symbol>.dzp.basis'), or 'reduced(full)' such as
    'sz(dzp)', which reads the 'full' file and then reduces it.
    """

    def __init__(self, symbol, name):
        super().__init__()
        self.symbol = symbol
        self.name = name

        # State of the <basis_function> element currently being read.
        self.type = None
        self.rc = None
        self.data = None  # list of text chunks while inside an element
        self.l = None
        self.n = None
        self.bf_j = []
        self.ribf_j = []

        # Extra keyword arguments for Basis (rgd, generatorattrs, ...).
        self._dct = {}

    def parse(self, filename=None, world=None):
        """Read from symbol.name.basis file and return a Basis.

        If *filename* is None the file '<symbol>.<name>.basis' (or
        '<symbol>.basis' when the name is None) is located with
        search_for_file; otherwise *filename* is read directly.

        Example of filename: N.dzp.basis. Use sz(dzp) to read
        the sz-part from the N.dzp.basis file."""
        from gpaw.setup_data import read_maybe_unzipping

        name = self.name
        reduced = None
        # A None name is legitimate (e.g. 'H.basis' via
        # parse_basis_filename) and must not hit the '(' membership test.
        if name is not None and '(' in name:
            assert name.endswith(')')
            reduced, name = name.split('(')
            name = name[:-1]

        if name is None:
            fullname = f'{self.symbol}.basis'
        else:
            fullname = f'{self.symbol}.{name}.basis'

        if filename is None:
            filename, source = search_for_file(fullname, world=world)
        else:
            source = read_maybe_unzipping(filename)

        self.filename = filename
        self.data = None
        xml.sax.parseString(source, self)

        # Basis.filename is declared as str (and get_description
        # concatenates it), so normalize path-like objects such as Path.
        basis = Basis(symbol=self.symbol, name=self.name,
                      filename=None if filename is None else str(filename),
                      bf_j=[*self.bf_j], ribf_j=[*self.ribf_j],
                      **self._dct)

        if reduced:
            basis = basis.reduce(reduced)

        return basis

    def startElement(self, name, attrs):
        dct = self._dct
        # For name == 'paw_basis' we can save attrs['version'], too.
        if name == 'generator':
            dct['generatorattrs'] = dict(attrs)
            self.data = []
        elif name == 'radial_grid':
            dct['rgd'] = radial_grid_descriptor(**attrs)
        elif name == 'basis_function':
            self.l = int(attrs['l'])
            self.rc = float(attrs['rc'])
            self.type = attrs.get('type')
            self.data = []
            if 'n' in attrs:
                self.n = int(attrs['n'])
            elif self.type and self.type[0].isdigit():
                # Fall back to the leading digit of a type such as '2s-sz'.
                # The type attribute is optional, so guard against None.
                self.n = int(self.type[0])
            else:
                self.n = None

    def characters(self, data):
        # Text is only collected while inside <generator>/<basis_function>.
        if self.data is not None:
            self.data.append(data)

    def endElement(self, name):
        if name == 'basis_function':
            phit_g = np.array([float(x) for x in ''.join(self.data).split()])
            bf = BasisFunction(self.n, self.l, self.rc, phit_g, self.type)
            # Also auxiliary basis functions are added here. They are
            # distinguished by their type='auxiliary'.
            if bf.type == 'auxiliary':
                self.ribf_j.append(bf)
            else:
                self.bf_j.append(bf)
            # Stop collecting whitespace between elements.
            self.data = None
        elif name == 'generator':
            self._dct['generatordata'] = ''.join(self.data)
            self.data = None
class BasisPlotter:
    """Plot the radial parts of a basis set's functions with matplotlib.

    Parameters:
        premultiply: plot r*phi(r) instead of phi(r).
        normalize: scale every curve by 1/sqrt(norm), where norm is the
            grid sum of (r*phi)**2 * dr, so each plotted function has unit
            norm.
        show: call ``plt.show()`` after plotting.
        save: save the figure (file name pattern, see :meth:`plot`).
        ext: extension of the default file name '<symbol>.<name>.<ext>'.
    """

    def __init__(self, premultiply=True, normalize=False,
                 show=False, save=False, ext='png'):
        self.premultiply = premultiply
        self.show = show
        self.save = save
        self.ext = ext
        self.default_filename = '%(symbol)s.%(name)s.' + ext

        self.title = 'Basis functions: %(symbol)s %(name)s'
        self.xlabel = 'radius [Bohr]'
        ylabel = r'\Phi(r)'
        if premultiply:
            ylabel = 'r' + ylabel
        self.ylabel = f'${ylabel}$'

        self.normalize = normalize

    def plot(self, basis, filename=None, ax=None, **plot_args):
        """Plot every function in ``basis.bf_j`` and return the Axes.

        A textual summary (norms, generator attributes and generator data)
        is also printed to stdout.

        Parameters:
            basis: the basis set to plot.
            filename: output file name pattern; it may contain
                %(symbol)s / %(name)s and is used only when ``save`` is
                set.  Defaults to ``self.default_filename``.
            ax: existing Axes to draw into; a new figure is made if None.
            **plot_args: extra keyword arguments forwarded to ``ax.plot``.
        """
        if ax is None:
            from matplotlib import pyplot as plt

            ax = plt.figure().gca()

        r_g = basis.rgd.r_g

        print('Element :', basis.symbol)
        print('Name :', basis.name)
        print()
        print('Basis functions')
        print('---------------')

        # norm = sum over the function's grid points of (r*phi)^2 * dr.
        norm_j = []
        for bf in basis.bf_j:
            ng = len(bf.phit_g)
            rphit_g = r_g[:ng] * bf.phit_g
            norm = (rphit_g**2 * basis.rgd.dr_g[:ng]).sum()
            norm_j.append(norm)
            print(bf.type, '[norm=%0.4f]' % norm)

        print()
        print('Generator')
        for key, item in basis.generatorattrs.items():
            print(' ', key, ':', item)
        print()
        print('Generator data')
        print(basis.generatordata)

        if self.premultiply:
            factor = r_g
        else:
            factor = np.ones_like(r_g)

        # One dash pattern per angular momentum; l >= 4 cycles through them
        # instead of raising IndexError.
        dashes_l = [(), (6, 3), (4, 1, 1, 1), (1, 1)]

        for norm, bf in zip(norm_j, basis.bf_j):
            ng = len(bf.phit_g)
            y_g = bf.phit_g * factor[:ng]
            if self.normalize:
                # norm is the squared norm, so divide by its square root to
                # obtain a unit-norm function.
                y_g /= np.sqrt(norm)
            ax.plot(r_g[:ng], y_g, label=bf.type[:12],
                    dashes=dashes_l[bf.l % len(dashes_l)], lw=2,
                    **plot_args)
        axis = ax.axis()
        rc = max([bf.rc for bf in basis.bf_j])
        newaxis = [0., rc, axis[2], axis[3]]
        ax.axis(newaxis)
        ax.legend()
        ax.set_title(self.title % basis.__dict__)
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)

        if filename is None:
            filename = self.default_filename
        if self.save:
            ax.get_figure().savefig(filename % basis.__dict__)

        if self.show:
            # Import here as well: when the caller passes ``ax``, the import
            # at the top of this method never ran and ``plt`` was unbound.
            from matplotlib import pyplot as plt

            plt.show()

        return ax
class CLICommand:
    """Plot basis set from FILE."""
    # NOTE(review): the docstring above is kept verbatim; command classes of
    # this shape usually expose it as the CLI help text - confirm before
    # rewording it.

    @staticmethod
    def add_arguments(parser):
        # Positional basis file, plus an optional output path for the plot.
        parser.add_argument('file', metavar='FILE')
        write_help = ('write plot to file inferring format from file '
                      'extension.')
        parser.add_argument('--write', metavar='FILE', help=write_help)

    @staticmethod
    def run(args):
        from pathlib import Path
        import matplotlib.pyplot as plt

        basis_path = Path(args.file)
        # The symbol and basis name have to be recovered from the file name
        # because the basis file itself does not record them.
        symbol, name = parse_basis_filename(basis_path.name)
        basis = Basis.read_path(symbol, name, path=basis_path)

        ax = BasisPlotter().plot(basis)
        if not args.write:
            plt.show()
            return
        ax.get_figure().savefig(args.write)