Coverage for gpaw/new/ase_interface.py: 65%
451 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
3import warnings
4from pathlib import Path
5from typing import Any, Callable
7import numpy as np
8from ase import Atoms
9from ase.units import Ha
11from gpaw import __version__
12from gpaw.core import UGArray
13from gpaw.core.arrays import XArrayWithNoData
14from gpaw.dft import GPAW, Parameters
15from gpaw.dos import DOSCalculator
16from gpaw.mpi import broadcast, synchronize_atoms
17from gpaw.new import Timer, trace
18from gpaw.new.calculation import (CalculationModeError, DFTCalculation,
19 ReuseWaveFunctionsError, units)
20from gpaw.new.gpw import GPWFlags, write_gpw
21from gpaw.new.logger import Logger
22from gpaw.new.pw.fulldiag import diagonalize
23from gpaw.new.xc import create_functional
24from gpaw.typing import Array1D, Array2D, Array3D
25from gpaw.utilities import pack_density
26from gpaw.utilities.memory import maxrss
# ASCII-art banner written to the log by write_header(); the ``{version}``
# placeholder is filled in with ``str.format(version=...)``.
LOGO = """\
  ___ ___ ___ _ _ _
 |   |   |_  | | | |
 | | | | | . | | | |
 |__ |  _|___|_____|  - {version}
 |___|_|
"""
def write_header(log: Logger, params: Parameters) -> None:
    """Write the start-of-run header to ``log``.

    The header consists of the GPAW logo (with version number), the
    generic run information produced by ``gpaw.io.logger.write_header``,
    the input parameters and every allowed GPAW environment variable
    that is actually set on the ``gpaw`` module.
    """
    import gpaw
    from gpaw.io.logger import write_header as header

    log(LOGO.format(version=__version__))
    header(log, log.comm)

    with log.indent('input parameters:'):
        log(params)

    with log.indent('\nenvironment variables:'):
        # Sentinel so that a variable whose value is None is still listed.
        unset = object()
        settings = []
        for envvar in sorted(gpaw.allowed_envvars):
            val = getattr(gpaw, envvar, unset)
            if val is unset:
                continue
            settings.append(f'{envvar}={val!r}')
        log(',\n'.join(settings))
def compare_atoms(a1: Atoms, a2: Atoms) -> set[str]:
    """Return the first kind of difference found between two Atoms objects.

    The checks are made in this order, and only the first difference
    found is reported: ``{'numbers'}``, ``{'pbc'}``, ``{'cell'}``,
    ``{'positions'}``.  An empty set means that the two systems are
    identical.  The comparison is exact (no tolerance).
    """
    if a1 is a2:
        return set()

    if len(a1.numbers) != len(a2.numbers) or (a1.numbers != a2.numbers).any():
        return {'numbers'}

    if (a1.pbc != a2.pbc).any():
        return {'pbc'}

    # Exact element-wise comparisons.  ``abs(diff).max() > 0.0`` would
    # raise ValueError for zero-atom systems, because the maximum of an
    # empty array is undefined.
    if np.any(np.asarray(a1.cell) != np.asarray(a2.cell)):
        return {'cell'}

    # Same number of atoms is guaranteed by the 'numbers' check above.
    if np.any(np.asarray(a1.positions) != np.asarray(a2.positions)):
        return {'positions'}

    return set()
class ASECalculator:
    """This is the ASE-calculator frontend for doing a GPAW calculation.

    The actual work is delegated to a ``DFTCalculation`` object, which is
    created lazily (or re-created/morphed) when atoms change.  Internally
    results are kept in atomic units; they are converted with the
    ``units`` table whenever they are handed back to the caller.
    """

    name = 'gpaw'
    # NOTE(review): presumably lets callers tell this calculator apart
    # from the old GPAW implementation -- confirm against users of `old`.
    old = False

    def __init__(self,
                 params: Parameters,
                 *,
                 log: Logger,
                 dft: DFTCalculation | None = None,
                 atoms: Atoms | None = None):
        self.params = params
        self.log = log
        self.comm = log.comm
        self._dft = dft
        self._atoms = atoms
        self.timer = Timer()
        self.hooks: dict[str, Callable] = {}
        write_header(log, params)
        # Cache for the ``wfs`` property: (FakeWFS, DFTCalculation it was
        # built from).  The -1 placeholders never match a real calculation.
        self._wfs_dft = -1, -1

    @property
    def dft(self) -> DFTCalculation:
        """The current DFT calculation.

        Raises AttributeError if no calculation has been created yet.
        """
        if self._dft is None:
            raise AttributeError
        return self._dft

    @property
    def atoms(self) -> Atoms:
        """Copy of the atoms used for the current calculation.

        Raises AttributeError if no atoms have been set yet.
        """
        if self._atoms is None:
            raise AttributeError
        return self._atoms

    def __repr__(self):
        return f'ASECalculator({self.params!r})'

    def iconverge(self, atoms: Atoms | None):
        """Iterate to self-consistent solution.

        Will also calculate "cheap" properties: energy, magnetic moments
        and dipole moment.

        This is a generator yielding one context object per SCF step.  If
        the existing calculation is already converged for ``atoms``,
        nothing is yielded.  Changes in atomic numbers or boundary
        conditions start a new calculation from scratch; a cell change
        tries to reuse the old wave functions; a pure position change
        moves the atoms in the existing calculation.
        """
        if atoms is None:
            atoms = self.atoms
        else:
            synchronize_atoms(atoms, self.comm)

        converged = True

        if self._dft is not None:
            changes = compare_atoms(self.atoms, atoms)
            if changes & {'numbers', 'pbc', 'cell'}:
                if 'numbers' not in changes:
                    # Remember magmoms if there are any:
                    magmom_a = self.dft.results.get('magmoms')
                    if magmom_a is not None and magmom_a.any():
                        atoms = atoms.copy()
                        assert atoms is not None  # MYPY: why is this needed?
                        atoms.set_initial_magnetic_moments(magmom_a)

                if changes & {'numbers', 'pbc'}:
                    self._dft = None  # start from scratch
                else:
                    try:
                        self.create_new_calculation_from_old(atoms)
                    except ReuseWaveFunctionsError:
                        self._dft = None  # start from scratch
                    else:
                        # The morphed calculation already has the new
                        # atoms, so there is nothing left to move.
                        converged = False
                        changes = set()

        if self._dft is None:
            self.create_new_calculation(atoms)
            converged = False
        elif changes:
            self.move_atoms(atoms)
            converged = False
        elif not self._dft.results:
            # Something cleared the results dict
            converged = False

        if converged:
            return

        if not self.dft.ibzwfs.has_wave_functions():
            # We have started from a gpw-file without wave functions
            self.create_new_calculation(atoms)

        assert self.hooks.keys() <= {'scf_step', 'converged'}

        with self.timer('SCF'):
            for ctx in self.dft.iconverge(
                    calculate_forces=self._calculate_forces):
                yield ctx
                self.hooks.get('scf_step', lambda ctx: None)(ctx)

        self.log(f'Converged in {ctx.niter} steps')

        # Calculate all the cheap things:
        self.dft.energy()
        self.dft.dipole()
        self.dft.magmoms()

        self.dft.write_converged()

        self.hooks.get('converged', lambda: None)()

    def calculate_property(self,
                           atoms: Atoms | None,
                           prop: str) -> Any:
        """Calculate (if not already calculated) a property.

        The ``prop`` string must be one of

        * energy
        * forces
        * stress
        * magmom
        * magmoms
        * dipole

        The value is returned in ASE units.  Raises KeyError for a
        property that is not available in the results.
        """
        for _ in self.iconverge(atoms):
            pass

        if prop == 'forces':
            with self.timer('Forces'):
                self.dft.forces()
        elif prop == 'stress':
            with self.timer('Stress'):
                self.dft.stress()
        elif prop not in self.dft.results:
            raise KeyError('Unknown property:', prop)

        return self.dft.results[prop] * units[prop]

    def get_property(self,
                     name: str,
                     atoms: Atoms | None = None,
                     allow_calculation: bool = True) -> Any:
        """Return property ``name`` in ASE units.

        With ``allow_calculation=False`` nothing is calculated.  In that
        case ``None`` is returned if no calculation exists yet, if the
        property has not been calculated, or if ``atoms`` differs from
        the atoms of the current calculation.
        """
        if not allow_calculation:
            # Use ``_dft`` directly: ``self.dft`` raises AttributeError
            # before the first calculation.
            if self._dft is None or name not in self._dft.results:
                return None
            if atoms is None or len(self.check_state(atoms)) == 0:
                return self._dft.results[name] * units[name]
            return None
        if atoms is None:
            atoms = self.atoms
        return self.calculate_property(atoms, name)

    def calculation_required(self, atoms, properties):
        """Return True if any of ``properties`` must be (re)calculated.

        This is the case when no calculation exists yet, when a property
        is missing from the results, or when ``atoms`` differs from the
        atoms of the current calculation.
        """
        if self._dft is None:
            return True
        if any(prop not in self._dft.results for prop in properties):
            return True
        return len(self.check_state(atoms)) > 0

    @property
    def results(self):
        """All available results converted to ASE units (empty if none)."""
        if self._dft is None:
            return {}
        return {name: value * units[name]
                for name, value in self.dft.results.items()}

    @trace
    def create_new_calculation(self, atoms: Atoms) -> None:
        """Set up a brand-new DFT calculation for ``atoms``."""
        with self.timer('Init'):
            self._dft = DFTCalculation.from_parameters(
                atoms, self.params, self.comm, self.log)
        self._atoms = atoms.copy()

    def create_new_calculation_from_old(self, atoms: Atoms) -> None:
        """Create a new calculation for ``atoms`` from the current one.

        May raise ReuseWaveFunctionsError (handled in ``iconverge()``).
        """
        with self.timer('Morph'):
            self._dft = self.dft.new(
                atoms, self.params, self.log)
        self._atoms = atoms.copy()

    def move_atoms(self, atoms):
        """Update the current calculation with new atomic positions."""
        with self.timer('Move'):
            self._dft = self.dft.move_atoms(atoms)
        self._atoms = atoms.copy()

    def _calculate_forces(self) -> Array2D:  # units: Ha/Bohr
        """Helper method for force-convergence criterium."""
        with self.timer('Forces'):
            self.dft._calculate_forces()
        return self.dft.results['forces']

    def __del__(self):
        self.log('---')
        self.timer.write(self.log)
        try:
            mib = maxrss() / 1024**2
            self.log(f'\nMax RSS: {mib:.3f}  # MiB')
        except NameError:
            pass
        self.log.close()

    def get_potential_energy(self,
                             atoms: Atoms | None = None,
                             force_consistent: bool = False) -> float:
        """Energy in eV ('free_energy' if ``force_consistent``)."""
        return self.calculate_property(atoms,
                                       'free_energy' if force_consistent else
                                       'energy')

    @trace
    def get_forces(self, atoms: Atoms | None = None) -> Array2D:
        """Forces in ASE units."""
        return self.calculate_property(atoms, 'forces')

    @trace
    def get_stress(self, atoms: Atoms | None = None) -> Array1D:
        """Stress tensor in ASE units."""
        return self.calculate_property(atoms, 'stress')

    def get_dipole_moment(self, atoms: Atoms | None = None) -> Array1D:
        return self.calculate_property(atoms, 'dipole')

    def get_magnetic_moment(self, atoms: Atoms | None = None) -> float:
        return self.calculate_property(atoms, 'magmom')

    def get_magnetic_moments(self, atoms: Atoms | None = None) -> Array1D:
        return self.calculate_property(atoms, 'magmoms')

    def get_non_collinear_magnetic_moment(self,
                                          atoms: Atoms | None = None
                                          ) -> Array1D:
        return self.calculate_property(atoms, 'non_collinear_magmom')

    def get_non_collinear_magnetic_moments(self,
                                           atoms: Atoms | None = None
                                           ) -> Array2D:
        return self.calculate_property(atoms, 'non_collinear_magmoms')

    def check_state(self, atoms, tol=1e-12):
        """Return the list of differences between ``atoms`` and ours.

        ``tol`` is accepted but currently ignored: the comparison done by
        ``compare_atoms()`` is exact.
        """
        return list(compare_atoms(self.atoms, atoms))

    def eigenvalues(self):
        """All eigenvalues in eV, available on every rank.

        The array is gathered on rank 0 and then broadcast.  Presumably it
        is indexed as (spin, k-point, band), judging by the ``_skn`` name.
        """
        eig_skn = self.dft.ibzwfs.get_all_eigs_and_occs()[0]
        return broadcast(eig_skn * Ha if self.comm.rank == 0 else None,
                         comm=self.comm)

    def occupations(self):
        """All occupation numbers, broadcast from rank 0 to every rank."""
        occ_skn = self.dft.ibzwfs.get_all_eigs_and_occs()[1]
        return broadcast(occ_skn if self.comm.rank == 0 else None,
                         comm=self.comm)

    def write(self,
              filename: str | Path,
              mode: str = '',
              precision: str = 'double',
              include_projections: bool = True) -> None:
        """Write calculator object to a file.

        Parameters
        ----------
        filename:
            File to be written.
        mode:
            Write mode.  Use ``mode='all'``
            to include wave functions in the file.
        precision:
            'double' (the default) or 'single'.
        include_projections:
            Use ``include_projections=False`` to not include
            the PAW-projections.
        """
        self.log(f'# Writing to {filename} (mode={mode!r})\n')

        flags = GPWFlags(include_projections=include_projections,
                         precision=precision, include_wfs=mode == 'all')
        write_gpw(filename, self.dft, flags=flags)

    @property
    def environment(self):
        return self.dft.pot_calc.environment

    # Old API:

    implemented_properties = ['energy', 'free_energy',
                              'forces', 'stress',
                              'dipole', 'magmom', 'magmoms']

    def icalculate(self, atoms, system_changes=None):
        """Generator version of ``calculate()``; see ``iconverge()``."""
        yield from self.iconverge(atoms)

    def new(self, **kwargs) -> ASECalculator:
        """Create a new calculator with our parameters updated by kwargs."""
        kwargs = {**self.params.todict(), **kwargs}
        return GPAW(**kwargs)

    def get_pseudo_wave_function(self, band, kpt=0, spin=None,
                                 periodic=False,
                                 broadcast=True,
                                 pad=True) -> Array3D | None:
        """Pseudo wave function for one band (None where not available)."""
        psit_R = self.dft.wave_functions(n1=band, n2=band + 1,
                                         kpt=kpt, spin=spin,
                                         periodic=periodic,
                                         broadcast=broadcast,
                                         _pad=pad)[0]
        if psit_R is not None:
            return psit_R.data
        return None

    def get_atoms(self):
        """Copy of our atoms with this calculator attached."""
        atoms = self.atoms.copy()
        atoms.calc = self
        return atoms

    def get_fermi_level(self) -> float:
        """Fermi level in eV."""
        return self.dft.ibzwfs.fermi_level * Ha

    def get_fermi_levels(self) -> Array1D:
        """Fermi levels in eV.

        Raises ValueError if there is only one Fermi level.
        """
        fl = self.dft.ibzwfs.fermi_levels
        assert fl is not None
        if len(fl) == 1:
            raise ValueError('Only one Fermi-level.')
        return fl * Ha

    def get_homo_lumo(self, spin: int | None = None) -> Array1D:
        """HOMO and LUMO energies in eV."""
        return self.dft.ibzwfs.get_homo_lumo(spin) * Ha

    def get_number_of_electrons(self):
        density = self.dft.density
        return (density.nvalence - density.charge +
                self.dft.pot_calc.environment.charge)

    def get_number_of_bands(self):
        return self.dft.ibzwfs.nbands

    def get_number_of_grid_points(self):
        return self.dft.density.nt_sR.desc.size

    def get_effective_potential(self, spin=0, broadcast=True):
        """Effective pseudo potential in eV (only ``spin=0`` supported).

        Returns None on ranks that do not receive the gathered data.
        """
        assert spin == 0
        vt_R = self.dft.potential.vt_sR[spin]
        vt_R = vt_R.to_pbc_grid().gather(broadcast=broadcast)
        return None if vt_R is None else vt_R.data * Ha

    def get_electrostatic_potential(self):
        """Electrostatic (Hartree) potential in eV on the full grid."""
        density = self.dft.density
        potential, _, _ = self.dft.pot_calc.calculate(density)
        vHt_x = potential.vHt_x
        if isinstance(vHt_x, UGArray):
            return vHt_x.gather(broadcast=True).to_pbc_grid().data * Ha

        # Not a real-space array: transform to the fine grid first.
        grid = self.dft.pot_calc.fine_grid
        return vHt_x.ifft(grid=grid).gather(broadcast=True).data * Ha

    def get_atomic_electrostatic_potentials(self):
        return self.dft.electrostatic_potential().atomic_potentials()

    def get_electrostatic_corrections(self):
        return self.dft.electrostatic_potential().atomic_corrections()

    def get_pseudo_density(self,
                           spin=None,
                           gridrefinement=1,
                           broadcast=True) -> Array3D | None:
        """Pseudo density summed over spin components.

        Only ``spin=None`` is supported.  Returns None on ranks that do
        not receive the gathered data.
        """
        assert spin is None
        nt_sr = self.dft.densities().pseudo_densities(
            grid_refinement=gridrefinement)
        nt_sr = nt_sr.gather(broadcast=broadcast)
        return None if nt_sr is None else nt_sr.data.sum(0)

    def get_all_electron_density(self,
                                 spin=None,
                                 gridrefinement=1,
                                 broadcast=True,
                                 skip_core=False):
        """All-electron density, summed over spins or for one ``spin``.

        Returns None on ranks that do not receive the gathered data.
        """
        n_sr = self.dft.densities().all_electron_densities(
            grid_refinement=gridrefinement,
            skip_core=skip_core)
        if spin is None:
            n_sr = n_sr.gather(broadcast=broadcast)
            return None if n_sr is None else n_sr.data.sum(0)
        n_r = n_sr[spin].gather(broadcast=broadcast)
        # Check the gathered array, not the distributed input: gather()
        # returns None on ranks that do not receive the data.
        return None if n_r is None else n_r.data

    def get_eigenvalues(self, kpt=0, spin=0, broadcast=True):
        """Eigenvalues in eV for one k-point and spin.

        With ``broadcast=True`` the values from rank 0 are sent to all
        ranks.
        """
        eig_n = self.dft.ibzwfs.get_eigs_and_occs(k=kpt, s=spin)[0] * Ha
        if broadcast:
            if self.comm.rank != 0:
                eig_n = np.empty(self.dft.ibzwfs.nbands)
            self.comm.broadcast(eig_n, 0)
        return eig_n

    def get_occupation_numbers(self, kpt=0, spin=0, broadcast=True,
                               raw=False):
        """Occupation numbers for one k-point and spin.

        Unless ``raw`` is true, they are multiplied by the k-point weight
        and the spin degeneracy.  With ``broadcast=True`` the values from
        rank 0 are sent to all ranks.
        """
        ibzwfs = self.dft.ibzwfs
        occ_n = ibzwfs.get_eigs_and_occs(k=kpt, s=spin)[1]
        if not raw:
            weight = ibzwfs.ibz.weight_k[kpt] * ibzwfs.spin_degeneracy
            # Out-of-place multiply: the array returned above may alias
            # internal storage and must not be rescaled in place.
            occ_n = occ_n * weight
        if broadcast:
            if self.comm.rank != 0:
                occ_n = np.empty(ibzwfs.nbands)
            self.comm.broadcast(occ_n, 0)
        return occ_n

    def get_reference_energy(self):
        return self.dft.setups.Eref * Ha

    def get_number_of_iterations(self):
        return self.dft.scf_loop.niter

    def get_bz_k_points(self):
        return self.dft.ibzwfs.ibz.bz.kpt_Kc.copy()

    def get_ibz_k_points(self):
        return self.dft.ibzwfs.ibz.kpt_kc.copy()

    def get_k_point_weights(self):
        # Return a copy, like the k-point getters above, so callers
        # cannot modify the calculation's internal weights.
        return self.dft.ibzwfs.ibz.weight_k.copy()

    def get_orbital_magnetic_moments(self):
        """Return the orbital magnetic moment vector for each atom.

        Raises CalculationModeError for collinear calculations and warns
        if spin-orbit coupling was not included.
        """
        density = self.dft.density
        if density.collinear:
            raise CalculationModeError(
                'Calculator is in collinear mode. '
                'Collinear calculations require spin–orbit '
                'coupling for nonzero orbital magnetic moments.')
        if not self.params.soc:
            warnings.warn('Non-collinear calculation was performed '
                          'without spin–orbit coupling. Orbital '
                          'magnetic moments may not be accurate.')
        return density.calculate_orbital_magnetic_moments()

    def calculate(self, atoms, properties=None, system_changes=None):
        """Calculate ``properties`` (default: energy) for ``atoms``."""
        if properties is None:
            properties = ['energy']

        for name in properties:
            self.calculate_property(atoms, name)

    @property
    def wfs(self):
        """Old-style wave-function object (cached per DFT calculation)."""
        wfs, dft = self._wfs_dft
        if dft is not self._dft:
            from gpaw.new.backwards_compatibility import FakeWFS
            wfs = FakeWFS(self.dft.ibzwfs,
                          self.dft.density,
                          self.dft.potential,
                          self.dft.setups,
                          self.comm,
                          self.dft.scf_loop.occ_calc,
                          self.dft.scf_loop.hamiltonian,
                          self.atoms,
                          scale_pw_coefs=True)
            self._wfs_dft = wfs, self._dft
        return wfs

    @property
    def density(self):
        """Old-style density object (a new wrapper on each access)."""
        from gpaw.new.backwards_compatibility import FakeDensity
        return FakeDensity(self.dft)

    @property
    def hamiltonian(self):
        """Old-style Hamiltonian object (a new wrapper on each access)."""
        from gpaw.new.backwards_compatibility import FakeHamiltonian
        return FakeHamiltonian(
            self.dft.ibzwfs, self.dft.density, self.dft.potential,
            self.dft.pot_calc, self.dft.results.get('free_energy'),
            self.dft.energies._energies['xc'])

    @property
    def spos_ac(self):
        return self.atoms.get_scaled_positions()

    @property
    def world(self):
        return self.comm

    @property
    def setups(self):
        return self.dft.setups

    @property
    def initialized(self):
        return self._dft is not None

    def get_xc_functional(self):
        return self.dft.pot_calc.xc.name

    def get_xc_difference(self, xcparams):
        """Calculate non-selfconsistent XC-energy difference.

        Returns E_xc[xcparams] - E_xc[current functional] in eV,
        evaluated with the current density (including PAW corrections).
        """
        dft = self.dft
        pot_calc = dft.pot_calc
        density = dft.density
        xc = create_functional(xcparams, pot_calc.fine_grid)
        if xc.type == 'MGGA' and density.taut_sR is None:
            # An MGGA needs the kinetic-energy density, which requires
            # wave functions.
            dft.ibzwfs.make_sure_wfs_are_read_from_gpw_file()
            if isinstance(dft.ibzwfs.wfs_qs[0][0].psit_nX, XArrayWithNoData):
                # No wave functions available: build new ones and
                # converge them in the fixed density and potential.
                builder = self.params.dft_component_builder(self.atoms,
                                                            log=dft.log)
                basis_set = builder.create_basis_set()
                ibzwfs = builder.create_ibz_wave_functions(
                    basis_set, dft.potential)
                ibzwfs.fermi_levels = dft.ibzwfs.fermi_levels
                dft.ibzwfs = ibzwfs
                dft.scf_loop.update_density_and_potential = False
                dft.converge()
            density.update_ked(dft.ibzwfs)
        exct = pot_calc.calculate_non_selfconsistent_exc(xc, density)
        dexc = 0.0
        for a, D_sii in density.D_asii.items():
            setup = self.setups[a]
            dexc += xc.calculate_paw_correction(
                setup, np.array([pack_density(D_ii) for D_ii in D_sii.real]))
        dexc = dft.ibzwfs.domain_comm.sum_scalar(dexc)
        return (exct + dexc - dft.energies._energies['xc']) * Ha

    def diagonalize_full_hamiltonian(self,
                                     nbands: int | None = None,
                                     scalapack=None,
                                     expert: bool | None = None) -> None:
        """Replace the wave functions with a full diagonalization.

        With ``nbands=None`` the largest number of bands available at all
        k-points is used.  ``scalapack`` is accepted but not used here;
        ``expert`` is deprecated and ignored.
        """
        if expert is not None:
            warnings.warn('Ignoring deprecated "expert" argument',
                          DeprecationWarning)
        dft = self.dft

        if nbands is None:
            nbands = min(wfs.array_shape(global_shape=True)[0]
                         for wfs in dft.ibzwfs)
            nbands = dft.ibzwfs.kpt_comm.min_scalar(nbands)
            assert isinstance(nbands, int)

        dft.scf_loop.occ_calc._set_nbands(nbands)
        ibzwfs = diagonalize(dft.potential,
                             dft.ibzwfs,
                             dft.scf_loop.occ_calc,
                             nbands,
                             dft.density.nvalence + dft.density.charge)
        dft.ibzwfs = ibzwfs
        # Record the new number of bands as an explicitly set parameter.
        self.params.nbands = ibzwfs.nbands
        if 'nbands' not in self.params._non_defaults:
            self.params._non_defaults.append('nbands')

    def gs_adapter(self):
        from gpaw.response.groundstate import ResponseGroundStateAdapter
        return ResponseGroundStateAdapter(self)

    def fixed_density(self,
                      *,
                      txt='-',
                      update_fermi_level: bool = False,
                      **kwargs) -> ASECalculator:
        """Return a new calculator converged in our density and potential.

        ``kwargs`` override our parameters.  Density and potential are
        redistributed to the new calculation and kept fixed during its
        SCF loop.  Only the convergence criteria other than energy,
        density and forces are applied.  The Fermi level is kept fixed
        unless ``update_fermi_level`` is true.
        """
        kwargs = {**self.params.todict(), **kwargs}
        params = Parameters(**kwargs)
        log = Logger(txt, self.comm)
        builder = params.dft_component_builder(self.atoms, log=log)
        basis_set = builder.create_basis_set()
        dft = self.dft
        comm1 = dft.ibzwfs.kpt_band_comm
        comm2 = builder.communicators['D']
        potential = dft.potential.redist(
            builder.grid,
            builder.electrostatic_potential_desc,
            builder.atomdist,
            comm1, comm2)
        density = dft.density.redist(builder.grid,
                                     builder.interpolation_desc,
                                     builder.atomdist,
                                     comm1, comm2)
        ibzwfs = builder.create_ibz_wave_functions(basis_set, potential)
        ibzwfs.fermi_levels = dft.ibzwfs.fermi_levels

        scf_loop = builder.create_scf_loop()
        scf_loop.update_density_and_potential = False
        scf_loop.fix_fermi_level = not update_fermi_level
        for name in ['energy', 'density', 'forces']:
            scf_loop.convergence.pop(name, None)

        dft = DFTCalculation(
            self.atoms, ibzwfs, density, potential,
            builder.setups,
            scf_loop,
            builder.create_potential_calculator(),
            log,
            params=params,
            energies=self.dft.energies)

        dft.converge()

        return ASECalculator(params,
                             log=log,
                             dft=dft,
                             atoms=self.atoms)

    def initialize(self, atoms):
        self.create_new_calculation(atoms)

    def converge_wave_functions(self):
        self.dft.ibzwfs.make_sure_wfs_are_read_from_gpw_file()

    def get_number_of_spins(self):
        return self.dft.density.ndensities

    @property
    def parameters(self):
        return self.params

    def dos(self,
            soc: bool = False,
            theta: float = 0.0,  # degrees
            phi: float = 0.0,  # degrees
            shift_fermi_level: bool = True) -> DOSCalculator:
        """Create DOS-calculator.

        Default is to ``shift_fermi_level`` to 0.0 eV.  For ``soc=True``,
        angles can be given in degrees.
        """
        return DOSCalculator.from_calculator(
            self, soc=soc,
            theta=theta, phi=phi,
            shift_fermi_level=shift_fermi_level)

    def band_structure(self):
        """Create band-structure object for plotting."""
        from ase.spectrum.band_structure import get_band_structure
        return get_band_structure(calc=self)

    @property
    def symmetry(self):
        return self.dft.ibzwfs.ibz.symmetries._old_symmetry

    def get_wannier_localization_matrix(self, nbands, dirG, kpoint,
                                        nextkpoint, G_I, spin):
        """Calculate integrals for maximally localized Wannier functions.

        ``dirG`` is accepted but not used here.
        """
        from gpaw.new.wannier import get_wannier_integrals
        grid = self.dft.density.nt_sR.desc
        k_kc = self.dft.ibzwfs.ibz.bz.kpt_Kc
        G_c = k_kc[nextkpoint] - k_kc[kpoint] - G_I

        return get_wannier_integrals(self.dft.ibzwfs,
                                     grid,
                                     spin, kpoint, nextkpoint, G_c, nbands)

    def initial_wannier(self, initialwannier, kpointgrid, fixedstates,
                        edf, spin, nbands):
        from gpaw.new.wannier import initial_wannier
        return initial_wannier(self.dft.ibzwfs,
                               initialwannier, kpointgrid, fixedstates,
                               edf, spin, nbands)

    def initialize_positions(self, atoms=None):
        """No-op (kept for compatibility with the old API)."""
        pass

    def set(self, eigensolver):
        """Replace the SCF eigensolver (only ``'etdm-fdpw'`` is supported).

        ``eigensolver`` is a dict with a ``'name'`` key.  The remaining
        items are passed as keyword arguments to the current eigensolver's
        ``new()`` method.  The caller's dict is not modified.
        """
        kwargs = dict(eigensolver)
        # Pop outside the assert: under ``python -O`` asserts are stripped,
        # and 'name' would otherwise be forwarded to ``new()``.
        name = kwargs.pop('name')
        assert name == 'etdm-fdpw'
        self.dft.scf_loop.eigensolver = self.dft.scf_loop.eigensolver.new(
            **kwargs)

    def todict(self):
        return self.params.todict()

    def get_nonselfconsistent_energies(self, type='beefvdw'):
        """Exchange/correlation contributions for a BEEF ensemble.

        ``type`` must be one of 'beefvdw', 'mbeef' or 'mbeefvdw';
        anything else raises NotImplementedError.
        """
        from gpaw.xc.bee import BEEFEnsemble
        if type not in ['beefvdw', 'mbeef', 'mbeefvdw']:
            raise NotImplementedError('Not implemented for type = %s' % type)
        # assert self.scf.converged
        bee = BEEFEnsemble(self)
        x = bee.create_xc_contributions('exch')
        c = bee.create_xc_contributions('corr')
        if type == 'beefvdw':
            return np.append(x, c)
        elif type == 'mbeef':
            return x.flatten()
        elif type == 'mbeefvdw':
            return np.append(x.flatten(), c)

    def get_bz_to_ibz_map(self):
        """Return indices from BZ to IBZ."""
        return self.dft.ibzwfs.ibz.bz2ibz_K.copy()