Coverage for gpaw/lcao/local_orbitals.py: 64%
336 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1from __future__ import annotations
3from collections import defaultdict
4from typing import Sequence
6import numpy as np
7from ase.data import covalent_radii
8from ase.data.colors import jmol_colors
9from ase.units import Bohr, Hartree
10from gpaw.calculator import GPAW
11from gpaw.lcao.tightbinding import TightBinding # as LCAOTightBinding
12from gpaw.lcao.tools import get_bfi
13from gpaw.typing import Array1D, Array2D, Array4D
14from gpaw.utilities.blas import r2k
15from gpaw.utilities.tools import lowdin, tri2full
16from scipy.linalg import eigh
def get_subspace(A_MM: Array2D, indices: Sequence[int]):
    """Return the square sub-matrix of ``A_MM`` spanned by ``indices``.

    The same ``indices`` select both rows and columns, so the result has
    shape ``(len(indices), len(indices))``.  ``A_MM`` must be a square
    2-D array.
    """
    is_square = A_MM.ndim == 2 and A_MM.shape[0] == A_MM.shape[1]
    assert is_square
    rows_wM = A_MM.take(indices, axis=0)
    return rows_wM.take(indices, axis=1)
def get_orthonormal_subspace(H_MM: Array2D,
                             S_MM: Array2D,
                             indices: Sequence[int] = None):
    """Solve the generalized eigenproblem ``H v = eps S v`` in a subspace.

    Parameters
    ----------
    H_MM, S_MM : array_like
        Square Hamiltonian and overlap matrices.
    indices : sequence of int, optional
        Basis functions spanning the subspace.  When omitted, the full
        matrices are diagonalized.

    Returns
    -------
    eps : ndarray
        Eigenvalues returned by ``scipy.linalg.eigh``.
    v : ndarray
        Eigenvectors (columns) returned by ``scipy.linalg.eigh``; they are
        orthonormal with respect to the (sub-)overlap matrix.
    """
    if indices is None:
        h_ww, s_ww = H_MM, S_MM
    else:
        h_ww = get_subspace(H_MM, indices)
        s_ww = get_subspace(S_MM, indices)
    eps_w, v_ww = eigh(h_ww, s_ww)
    return eps_w, v_ww
def subdiagonalize(H_MM: Array2D,
                   S_MM: Array2D,
                   blocks: Sequence[Sequence[int]]):
    """Diagonalize the Hamiltonian separately within each block.

    Parameters
    ----------
    H_MM, S_MM : array_like
        Square Hamiltonian and overlap matrices.
    blocks : sequence of sequence of int
        Groups of basis-function indices.  Each block is diagonalized on its
        own with the generalized eigenproblem ``H v = eps S v``.

    Returns
    -------
    epsx_M : numpy.ma.MaskedArray
        Eigenvalues stored at the positions of their block; entries that
        belong to no block are masked.
    v_MM : ndarray
        Rotation matrix: identity outside the blocks, block eigenvectors
        inside.  It is complex whenever ``H_MM`` or ``S_MM`` is complex.
    """
    nM = len(H_MM)
    # Promote to (at least) float64 and keep complex eigenvectors complex:
    # writing complex ``v`` into a real identity matrix would silently
    # discard the imaginary part.
    dtype = np.result_type(H_MM, S_MM, np.float64)
    v_MM = np.eye(nM, dtype=dtype)
    eps_M = np.zeros(nM)
    mask_M = np.ones(nM, dtype=int)
    for block in blocks:
        eps, v = get_orthonormal_subspace(H_MM, S_MM, block)
        v_MM[np.ix_(block, block)] = v
        eps_M[block] = eps
        mask_M[block] = 0  # Unmask energies that were computed.
    epsx_M = np.ma.masked_array(eps_M, mask=mask_M)  # type: ignore
    return epsx_M, v_MM
def subdiagonalize_atoms(calc: GPAW,
                         H_MM: Array2D,
                         S_MM: Array2D,
                         atom_indices: int | Sequence[int] | None = None):
    """Subdiagonalize the atomic sub-spaces of an LCAO calculation.

    Each selected atom ``a`` contributes one block made of its
    ``calc.wfs.setups[a].nao`` basis functions, starting at
    ``calc.wfs.basis_functions.M_a[a]``.

    Parameters
    ----------
    calc : GPAW
        LCAO calculator providing the basis-function layout.
    H_MM, S_MM : array_like
        Square Hamiltonian and overlap matrices in the LCAO basis.
    atom_indices : int or sequence of int, optional
        Atom(s) whose sub-space is diagonalized.  Defaults to all atoms.
        A single Python or NumPy integer selects one atom.

    Returns
    -------
    The ``(epsx_M, v_MM)`` pair returned by ``subdiagonalize``.
    """
    if atom_indices is None:
        atom_indices = range(len(calc.atoms))
    elif isinstance(atom_indices, (int, np.integer)):
        # Also accept NumPy integers (e.g. from np.argmin / np.where).
        atom_indices = [atom_indices]
    block_lists = []
    for a in atom_indices:
        M = calc.wfs.basis_functions.M_a[a]
        block_lists.append(range(M, M + calc.wfs.setups[a].nao))
    return subdiagonalize(H_MM, S_MM, block_lists)
def get_orbitals(calc: GPAW, U_Mw: Array2D, q: int = 0):
    """Expand LCAO coefficient vectors onto the real-space grid.

    Parameters
    ----------
    calc : GPAW
        LCAO calculator.
    U_Mw : array_like
        LCAO expansion coefficients; column ``w`` defines orbital ``w``.
    q : int, optional
        Forwarded as ``q`` to ``lcao_to_grid`` (presumably the k-point
        index -- TODO confirm).  Default 0.

    Returns
    -------
    w_wG : ndarray
        One grid function per column of ``U_Mw``, allocated with
        ``calc.wfs.gd.zeros`` in the wave-function dtype.
    """
    norb = U_Mw.shape[1]
    wfs = calc.wfs
    coefs_wM = np.ascontiguousarray(U_Mw.T).astype(wfs.dtype)
    grid_wG = wfs.gd.zeros(norb, dtype=wfs.dtype)
    wfs.basis_functions.lcao_to_grid(coefs_wM, grid_wG, q=q)
    return grid_wG
def get_xc(calc: GPAW, v_wG: Array4D, P_awi=None):
    """Get the exchange-correlation part of the Hamiltonian in an orbital basis.

    Only the pseudo (smooth) part is integrated.  The XC potential of the
    first spin channel is evaluated on the fine grid, restricted to the
    coarse grid, and sandwiched between the orbitals.

    Parameters
    ----------
    calc : GPAW
        Calculator holding the density, Hamiltonian and grid descriptors.
    v_wG : array_like
        Orbitals on the coarse real-space grid, one per leading index.
    P_awi : dict, optional
        Atomic projections.  PAW corrections are not implemented, so any
        value other than ``None`` raises.

    Returns
    -------
    xc_ww : ndarray
        Matrix of the pseudo XC potential, scaled by ``Hartree`` (eV).

    Raises
    ------
    NotImplementedError
        If ``P_awi`` is given.
    """
    # Atomic PAW corrections required? XXX
    # Fail before doing any (potentially expensive) grid work.
    if P_awi is not None:
        raise NotImplementedError(
            'Atomic PAW corrections not included. '
            'Have a look at pwf2::get_xc2 for inspiration.')

    if calc.density.nt_sg is None:
        calc.density.interpolate_pseudo_density()
    nt_sg = calc.density.nt_sg
    vxct_sg = calc.density.finegd.zeros(calc.wfs.nspins)
    calc.hamiltonian.xc.calculate(calc.density.finegd, nt_sg, vxct_sg)
    vxct_G = calc.wfs.gd.empty()
    # Only spin channel 0 is used.
    calc.hamiltonian.restrict_and_collect(vxct_sg[0], vxct_G)

    # Integrate pseudo part: r2k writes one triangle of the symmetric
    # result and tri2full completes the matrix from it.
    Nw = len(v_wG)
    xc_ww = np.empty((Nw, Nw))
    r2k(0.5 * calc.wfs.gd.dv, v_wG, vxct_G * v_wG, 0.0, xc_ww)
    tri2full(xc_ww, "L")

    return xc_ww * Hartree
def get_Fcore():
    """Unimplemented placeholder; always returns ``None``."""
    return None
class BasisTransform:
    """Class to perform a basis transformation.

    Attributes
    ----------
    U_MM : array_like
        2-D rotation matrix between two bases.  Column ``x`` holds the
        expansion coefficients of new basis function ``x`` in the old basis.
    indices : array_like, optional
        1-D array of sub-indices of the new basis.
    U_Mw : array_like or None
        Same as `U_MM` but restricted to the `indices` columns of the new
        basis; ``None`` when no `indices` are given.

    Methods
    -------
    get_rotation(keep_rest=False)
        Return the rotation matrix used by the other methods.
    rotate_matrix(A_MM, keep_rest=False)
        Rotate a matrix.
    rotate_projections(P_aMi, keep_rest=False)
        Rotate PAW atomic projections.
    rotate_function(Psi_MG, keep_rest=False)
        Rotate functions given in the old basis.
    """

    def __init__(self, U_MM: Array2D, indices: Sequence[int] = None) -> None:
        """
        Parameters
        ----------
        See class docstring
        """
        self.U_MM = U_MM
        self.indices = indices
        self.U_Mw: Array2D | None
        if indices is not None:
            self.U_Mw = np.ascontiguousarray(U_MM[:, indices])
        else:
            self.U_Mw = None

    def get_rotation(self, keep_rest: bool = False):
        """Return ``U_MM`` if `keep_rest` or no `indices`, else ``U_Mw``."""
        if keep_rest or self.U_Mw is None:
            return self.U_MM
        return self.U_Mw

    def rotate_matrix(self, A_MM: Array2D, keep_rest: bool = False):
        """Return ``U^H A U`` with ``U`` from :meth:`get_rotation`."""
        U_Mx = self.get_rotation(keep_rest)
        return U_Mx.T.conj() @ A_MM @ U_Mx

    def rotate_projections(self, P_aMi, keep_rest: bool = False):
        """Rotate projections of every atom: ``P_xi = sum_M U_Mx P_Mi``.

        Parameters
        ----------
        P_aMi : dict
            Maps atom index to an array whose first axis is the old basis.
        keep_rest : bool
            Use the full ``U_MM`` instead of ``U_Mw``.

        Returns
        -------
        dict
            Same keys, with the first axis in the new basis.
        """
        U_Mx = self.get_rotation(keep_rest)
        P_awi = {}
        for a, P_Mi in P_aMi.items():
            P_awi[a] = np.tensordot(U_Mx, P_Mi, axes=([0], [0]))
        return P_awi

    def rotate_function(self, Psi_MG: Array4D, keep_rest: bool = False):
        """Return ``psi_x = sum_M U_Mx Psi_M`` (contract the first axis)."""
        U_Mx = self.get_rotation(keep_rest)
        return np.tensordot(U_Mx, Psi_MG, axes=([0], [0]))
class EffectiveModel(BasisTransform):
    """Class for an effective model.

    See Also
    --------
    BasisTransform

    Methods
    -------
    get_static_correction(H_MM, S_MM, z=0. + 1e-5j)
        Hybridization of the effective model with the rest evaluated at `z`.
    """

    def __init__(self,
                 U_MM: Array2D,
                 indices: Sequence[int],
                 S_MM: Array2D = None) -> None:
        """
        See Also
        --------
        BasisTransform

        Parameters
        ----------
        S_MM : array_like, optional
            2-D LCAO overlap matrix.  If provided, the model columns
            ``U_MM[:, indices]`` are Lowdin-orthonormalized with respect to
            it.  The caller's `U_MM` is not modified; a copy is stored.
        """
        if S_MM is not None:
            # The base-class attributes (U_Mw, rotate_matrix) are not set up
            # yet here, so work on a private copy of the model columns.
            U_Mw = np.ascontiguousarray(U_MM[:, indices])
            lowdin(U_Mw, U_Mw.T.conj() @ S_MM @ U_Mw)
            # Off-diagonal elements are only numerically zero, so an
            # absolute tolerance is required.
            np.testing.assert_allclose(U_Mw.T.conj() @ S_MM @ U_Mw,
                                       np.eye(U_Mw.shape[1]), atol=1e-8)
            # Copy: ``U_MM[:]`` would be a view and overwrite the caller's
            # matrix.
            U_MM = np.array(U_MM, copy=True)
            U_MM[:, indices] = U_Mw

        super().__init__(U_MM, indices)

    def get_static_correction(self,
                              H_MM: Array2D,
                              S_MM: Array2D,
                              z: complex = 0. + 1e-5j):
        """Get static correction to model Hamiltonian.

        The correction is the real part of ``G0^-1(z) - G_w^-1(z)``.
        ``G0^-1 = z S_ww - H_ww`` is the uncoupled model.  ``G_w`` is the
        full Green's function projected onto the model orbitals.

        Parameters
        ----------
        H_MM, S_MM : array_like
            2-D LCAO Hamiltonian and overlap matrices.
        z : complex
            Energy with a small positive imaginary shift.

        Returns
        -------
        ndarray
            Real ``(len(indices), len(indices))`` correction matrix.
        """
        w = self.indices  # Alias
        assert w is not None

        Hp_MM = self.rotate_matrix(H_MM, keep_rest=True)
        Sp_MM = self.rotate_matrix(S_MM, keep_rest=True)
        Up_Mw = Sp_MM[:, w].dot(np.linalg.inv(Sp_MM[np.ix_(w, w)]))

        H_ww = self.rotate_matrix(H_MM)
        S_ww = self.rotate_matrix(S_MM)

        # Coupled
        G = np.linalg.inv(z * Sp_MM - Hp_MM)
        G_inv = np.linalg.inv(Up_Mw.T.conj() @ G @ Up_Mw)
        # Uncoupled
        G0_inv = z * S_ww - H_ww
        # Hybridization
        D0 = G0_inv - G_inv
        return D0.real

    def __len__(self):
        return len(self.indices)
class Subdiagonalization(BasisTransform):
    """Subdiagonalize an LCAO Hamiltonian block by block.

    Attributes
    ----------
    blocks : list of array_like
        Blocks of basis-function indices that are diagonalized separately.
    H_MM, S_MM : array_like
        2-D LCAO Hamiltonian and overlap matrices.
    U_MM : array_like
        2-D rotation matrix that subdiagonalizes the LCAO Hamiltonian.
    eps_M : numpy.ma.MaskedArray
        1-D array of local orbital energies; entries outside every block
        are masked.
    groups : dict or None
        Result of the most recent grouping call, ``None`` before any call.

    Methods
    -------
    group_energies(decimals=1)
        Group local orbitals based on energy.
    group_symmetries(decimals=1, cutoff=0.9)
        Group local orbitals based on symmetries and energy.
    get_model(indices, ortho=False)
        Build an effective model from an array of indices.
    """

    def __init__(self,
                 H_MM: Array2D,
                 S_MM: Array2D,
                 blocks: Sequence[Sequence[int]]) -> None:
        """
        Parameters
        ----------
        See class docstring
        """
        self.H_MM = H_MM
        self.S_MM = S_MM
        self.blocks = blocks
        eps_M, U_MM = subdiagonalize(H_MM, S_MM, blocks)
        self.eps_M = eps_M
        super().__init__(U_MM)
        # Groups of local orbitals with the same symmetry.
        self.groups: dict[float, list[int]] | None = None

    def group_energies(self, decimals: int = 1):
        """Group local orbitals with a similar energy.

        Parameters
        ----------
        decimals : int
            Round energies to the given number of decimals.

        Returns
        -------
        defaultdict
            Maps each rounded energy to the list of unmasked orbital indices
            having that energy, in increasing index order.  The same object
            is stored in ``self.groups``.
        """
        rounded = self.eps_M.round(decimals)
        grouped = defaultdict(list)
        for idx in np.where(~rounded.mask)[0]:
            grouped[rounded[idx]].append(idx)
        self.groups = grouped  # type: ignore[assignment]
        return grouped

    def group_symmetries(self, decimals: int = 1, cutoff: float = 0.9):
        """Group local orbitals with a similar spatial symmetry and energy.

        Orbitals from two different blocks of equal size are linked when
        ``2 |u1|.|u2| / (|u1|.|u1| + |u2|.|u2|) >= cutoff``.  Here ``|u|`` is
        the element-wise absolute value of an orbital's coefficient vector
        inside its block.  Linked orbitals are collected under the smallest
        index of their chain and then merged by rounded energy.

        Parameters
        ----------
        decimals : int
            Round energies to the given number of decimals.
        cutoff : float
            Sets minimum degree of overlap. Can be any value between 0 and 1.

        Returns
        -------
        dict
            Rounded energy (sorted) -> sorted list of orbital indices.  The
            same object is stored in ``self.groups``.
        """
        blocks = self.blocks
        # Orbital -> keyword of the first link it appeared in as the larger
        # index.  Later links of the same orbital never overwrite it.
        keyword_of: dict = {}
        linked = defaultdict(set)
        rows, cols = np.triu_indices(len(blocks), k=1)
        for b1, b2 in zip(rows.tolist(), cols.tolist()):
            block1 = blocks[b1]
            block2 = blocks[b2]
            norb = len(block1)
            if norb != len(block2):
                # Blocks with different dimensions not compatible.
                continue
            U1 = self.U_MM[np.ix_(block1, block1)]
            U2 = self.U_MM[np.ix_(block2, block2)]
            for o1 in range(norb):
                for o2 in range(norb):
                    v1 = abs(U1[:, o1])
                    v2 = abs(U2[:, o2])
                    o12 = 2 * v1.dot(v2) / (v1.dot(v1) + v2.dot(v2))
                    if o12 >= cutoff:
                        first = block1[o1]
                        second = block2[o2]
                        # Orbital with the minimal index acts as keyword.
                        low, high = min(first, second), max(first, second)
                        keyword = keyword_of.get(low, low)
                        keyword_of.setdefault(high, keyword)
                        linked[keyword].add(high)
        # Try to further group by energy.
        by_energy: dict[float, list[int]] = defaultdict(list)
        for key, members in linked.items():
            members.add(key)
            by_energy[self.eps_M[key].round(decimals)] += members
        self.groups = {energy: sorted(by_energy[energy])
                       for energy in sorted(by_energy)}
        return self.groups

    def get_model(self,
                  indices: Sequence[int],
                  ortho: bool = False) -> EffectiveModel:
        """Extract an effective model from the subdiagonalized space.

        Parameters
        ----------
        indices : array_like
            1-D array of indices (in the new basis) to include in the model.
        ortho : bool, default=False
            Whether to orthogonalize the model basis using ``self.S_MM``.
        """
        overlap = self.S_MM if ortho else None
        return EffectiveModel(self.U_MM, indices, overlap)
class LocalOrbitals(TightBinding):
    """Local Orbitals.

    Builds effective models from a subdiagonalized LCAO Hamiltonian.

    Attributes
    ----------
    calc : GPAW
        LCAO calculator the matrices are taken from.
    gamma : bool
        Whether the calculation is Gamma-point only.
    H_NMM, S_NMM : ndarray
        LCAO Hamiltonian (eV) and overlap matrices, one per cell ``N``.
    N0 : int
        Index of the home unit cell (R = [0, 0, 0]) along the first axis.
    subdiag : Subdiagonalization or None
        Set by :meth:`subdiagonalize`.
    blocks : list of array_like
        Blocks used by the last :meth:`subdiagonalize` call.
    groups : dict
        Groups of local orbitals, set by :meth:`groupby`.
    model : EffectiveModel or None
        Set by :meth:`take_model`; reset by :meth:`groupby`.
    indices : list of int
        Orbitals of the current model, set by :meth:`take_model`.
    H_Nww, S_Nww : ndarray
        Model Hamiltonian (eV) and overlap, set by :meth:`take_model`.

    Methods
    -------
    subdiagonalize(symbols=None, blocks=None, groupby='energy')
        Subdiagonalize the LCAO Hamiltonian.
    groupby(method='energy', decimals=1, cutoff=0.9)
        Group the local orbitals.
    take_model(indices=None, minimal=True, cutoff=1e-3, ortho=False)
        Take an effective model of local orbitals.
    """

    def __init__(self, calc: GPAW):
        self.calc = calc
        self.gamma = calc.wfs.kd.gamma  # Gamma point calculation
        self.subdiag: Subdiagonalization | None = None
        self.model: EffectiveModel | None = None

        if self.gamma:
            h = calc.hamiltonian
            wfs = calc.wfs
            kpt = wfs.kpt_u[0]

            H_MM = wfs.eigensolver.calculate_hamiltonian_matrix(h, wfs, kpt)
            S_MM = wfs.S_qMM[kpt.q]
            # XXX Converting to full matrices here
            tri2full(H_MM)
            tri2full(S_MM)
            self.H_NMM = H_MM[None, ...] * Hartree  # eV
            self.S_NMM = S_MM[None, ...]
            self.N0 = 0
        else:
            super().__init__(calc.atoms, calc)
            # Bloch to real
            self.H_NMM, self.S_NMM = TightBinding.h_and_s(self)
            self.H_NMM *= Hartree  # eV
            # The home cell is the lattice vector of zero length.  Use the
            # squared norms: mutual dot products of non-zero R vectors can
            # also be <= 0 (orthogonal or antiparallel vectors).
            R2_N = (self.R_cN ** 2).sum(axis=0)
            try:
                self.N0 = int(np.flatnonzero(R2_N < 1e-13)[0])
            except IndexError as exc:
                raise RuntimeError(
                    "Must include central unit cell, i.e. R=[0,0,0].") from exc

    def subdiagonalize(self,
                       symbols: Array1D = None,
                       blocks: Sequence[Sequence[int]] = None,
                       groupby: str = 'energy'):
        """Subdiagonalize Hamiltonian and overlap matrices.

        Parameters
        ----------
        symbols : array_like, optional
            Element or elements to subdiagonalize.  One block is created per
            matching atom; this takes precedence over `blocks`.
        blocks : list of array_like, optional
            List of blocks to subdiagonalize.
        groupby : {'energy', 'symmetry'}, optional
            Group local orbitals based on energy or
            symmetry and energy. Default is 'energy'.

        Raises
        ------
        RuntimeError
            If neither `symbols` nor `blocks` is given.
        """
        if symbols is not None:
            atoms = self.calc.atoms.symbols.search(symbols)
            blocks = [get_bfi(self.calc, [c]) for c in atoms]
        if blocks is None:
            raise RuntimeError(
                "User must provide either the element(s) "
                "or a list of blocks to subdiagonalize.")
        self.blocks = blocks
        self.subdiag = Subdiagonalization(
            self.H_NMM[self.N0], self.S_NMM[self.N0], blocks)

        self.groupby(groupby)

    def groupby(self,
                method: str = 'energy',
                decimals: int = 1,
                cutoff: float = 0.9):
        """Group local orbitals by energy, or by symmetry and energy.

        Parameters
        ----------
        method : {'energy', 'symmetry'}, optional
            Group local orbitals based on energy or
            symmetry and energy. Default is 'energy'.
        decimals, cutoff : optional
            Parameters passed to the group methods.

        Raises
        ------
        RuntimeError
            If `method` is not recognised.
        """
        assert self.subdiag is not None
        if method == 'energy':
            self.groups = self.subdiag.group_energies(decimals=decimals)
        elif method == 'symmetry':
            self.groups = self.subdiag.group_symmetries(
                decimals=decimals, cutoff=cutoff)
        else:
            raise RuntimeError(
                f"Invalid method type. {method} not in ('energy', 'symmetry')")
        # Ensure previous model is invalid.
        self.model = None

    def take_model(self,
                   indices: list[int] = None,
                   minimal: bool = True,
                   cutoff: float = 1e-3,
                   ortho: bool = False):
        """Build an effective model.

        Parameters
        ----------
        indices : array_like
            1-D array of indices to include in the model
            from the new basis.  If omitted, the orbital whose (rounded)
            energy is closest to the (rounded) Fermi level is taken from each
            block.  The caller's object is never modified.
        minimal : bool, default=True
            Whether to add (minimal=False) or not (minimal=True)
            the orbitals with a coupling larger than `cutoff` with any of the
            orbitals specified by `indices`.
        cutoff : float
            Cutoff value for the maximum matrix element connecting a group
            with the minimal model.
        ortho : bool, default=False
            Whether to orthogonalize the model.

        Raises
        ------
        RuntimeError
            If :meth:`subdiagonalize` has not been called.
        """
        if self.subdiag is None:
            raise RuntimeError("Not yet subdiagonalized.")

        eps = self.subdiag.eps_M.round(1)
        indices_from_input = indices is not None

        if indices is None:
            # Find active orbitals with energy closest to Fermi.
            fermi = round(self.calc.get_fermi_level(), 1)
            indices = []  # Min distance index for each block
            for block in self.blocks:
                eb = eps[block]
                ib = np.abs(eb - fermi).argmin()
                indices.append(block[ib])
        else:
            # Private copy: the list may be extended below and is stored on
            # self, which must not alter the caller's object (an ndarray
            # would even be added to element-wise by ``+=``).
            indices = list(indices)

        if not minimal:
            # Find orbitals that connect to active with a matrix
            # element larger than cutoff.
            # Couplings are read from the cell following the home cell
            # (the home cell itself for Gamma-only calculations).
            H_MM = self.H_NMM[(self.N0 + 1) % len(self.H_NMM)]
            H_MM = self.subdiag.rotate_matrix(H_MM)

            extend = []
            for group in self.groups.values():
                if np.isin(group, indices).any():
                    continue
                if np.abs(H_MM[np.ix_(indices, group)]).max() > cutoff:
                    extend += group

            # Expand model
            indices += extend

        self.indices = indices
        self.model = self.subdiag.get_model(indices, ortho=ortho)

        if self.gamma:
            H_Nww = self.model.rotate_matrix(self.H_NMM[0])[None, ...]
            S_Nww = self.model.rotate_matrix(self.S_NMM[0])[None, ...]
        else:
            # Bypass parent's LCAO construction.
            shape = (self.R_cN.shape[1],) + 2 * (len(self.indices),)
            dtype = self.H_NMM.dtype
            H_Nww = np.empty(shape, dtype)
            S_Nww = np.empty(shape, dtype)

            for N, (H_MM, S_MM) in enumerate(zip(self.H_NMM, self.S_NMM)):
                H_Nww[N] = self.model.rotate_matrix(H_MM)
                S_Nww[N] = self.model.rotate_matrix(S_MM)
        self.H_Nww = H_Nww
        self.S_Nww = S_Nww

        if minimal and not indices_from_input:
            print("Add static correction.")
            # Add static correction of hybridization to minimal model.
            self.H_Nww[self.N0] += self.model.get_static_correction(
                self.H_NMM[self.N0], self.S_NMM[self.N0])

    def h_and_s(self):
        """Return the model matrices ``H_Nww`` (in Hartree) and ``S_Nww``.

        Overrides ``TightBinding.h_and_s``; presumably consumed by
        ``TightBinding.band_structure`` -- TODO confirm.
        """
        # Hartree units.
        # Bypass TightBinding method.
        eV = 1 / Hartree
        return self.H_Nww * eV, self.S_Nww

    def band_structure(self, path_kc, blochstates=False):
        """Band structure of the model via ``TightBinding.band_structure``.

        ``H_NMM``/``S_NMM`` are restored afterwards.
        """
        # Brute force hack to restore matrices.
        H_NMM = self.H_NMM
        S_NMM = self.S_NMM
        ret = TightBinding.band_structure(self, path_kc, blochstates)
        self.H_NMM = H_NMM
        self.S_NMM = S_NMM
        return ret

    def get_hamiltonian(self):
        """Get the Hamiltonian in the home unit cell."""
        return self.H_Nww[self.N0]

    def get_overlap(self):
        """Get the overlap in the home unit cell."""
        return self.S_Nww[self.N0]

    def get_orbitals(self, indices):
        """Get orbitals (columns of the rotation matrix) on the real-space grid."""
        if self.model is None:
            basis = self.subdiag
        else:
            # Maybe model is orthogonal and subdiag does not know.
            basis = self.model
        return get_orbitals(self.calc, basis.U_MM[:, indices])

    def plot_group(self, group):
        """Plot the orbitals of ``self.groups[group]``."""
        return plot2D_orbitals(self, self.groups[group])

    def get_projections(self, q=0):
        """Rotate the PAW projections ``P_aqMi[q]`` into the model basis.

        Raises
        ------
        RuntimeError
            If no model has been taken yet.
        """
        if self.model is None:
            raise RuntimeError("No model. Call take_model first.")
        P_aMi = {a: P_aqMi[q] for a, P_aqMi in self.calc.wfs.P_aqMi.items()}
        return self.model.rotate_projections(P_aMi)

    def get_xc(self):
        """Exchange-correlation matrix of the model orbitals.

        Delegates to the module-level ``get_xc`` with the model orbitals
        and their PAW projections.  That function currently raises
        ``NotImplementedError`` when projections are passed.

        Raises
        ------
        RuntimeError
            If no model has been taken yet.
        """
        if self.model is None:
            raise RuntimeError("No model. Call take_model first.")
        return get_xc(self.calc, self.get_orbitals(self.model.indices),
                      self.get_projections())

    def get_Fcore(self):
        """Unimplemented placeholder."""
        pass
634# Plotting tools
def get_plane_dirs(plane):
    """Get normal and in-plane directions for a plane.

    The plane is identified by two of the characters {'x', 'y', 'z'}.

    Parameters
    ----------
    plane : str
        Pair of chars identifying the plane, e.g. 'xy'.  The order of the
        characters does not matter.

    Returns
    -------
    norm_dir : int
        Normal direction (0, 1 or 2 for x, y, z).
    plane_dirs : list of int
        In-plane directions, in ascending order.
    """
    plane_dirs = ['xyz'.index(i) for i in sorted(plane)]
    norm_dir = [i for i in [0, 1, 2] if i not in plane_dirs]
    return norm_dir[0], plane_dirs
def get_atoms(calc, indices):
    """Get the atom each orbital index belongs to.

    Parameters
    ----------
    calc : GPAW
        Calculator
    indices : array_like
        Orbital (basis-function) indices for which to retrieve the atoms to
        which they belong.

    Returns
    -------
    a_list : list of int
        Atom index for each entry of `indices`, in the same order
        (duplicates included).
    unique : list of int
        Positions in `a_list` where an atom occurs for the first time.

    Raises
    ------
    ValueError
        If an index is not among the basis functions of any atom.
    """
    # Map every basis-function index to the atom that owns it.
    owner_M = {}
    for a in range(len(calc.atoms)):
        for M in get_bfi(calc, [a]):
            owner_M[int(M)] = a

    a_list = []
    for index in np.asarray(indices).ravel():
        try:
            a_list.append(owner_M[int(index)])
        except KeyError:
            raise ValueError(
                f'Orbital index {index} does not belong to any atom.'
            ) from None

    seen = set()
    unique = []
    for pos, a in enumerate(a_list):
        if a not in seen:
            seen.add(a)
            unique.append(pos)
    return a_list, unique
def plot2D_orbitals(los, indices, plane='yz'):
    """Plot a 2D slice of the orbitals.

    Orbital ``w`` is cut by the plane (given by `plane`) that passes through
    the position of the atom ``get_atoms`` returns for it.

    Parameters
    ----------
    los : LocalOrbitals
        Local orbital wrapper
    indices : array_like
        List of orbitals to display
    plane : str, optional
        Pair of chars identifying the plane, by default 'yz'.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure with one panel per orbital (at most 6 panels per row).
    """
    import matplotlib.pyplot as plt

    norm_dir, plane_dirs = get_plane_dirs(plane)

    calc = los.calc
    atoms = calc.atoms

    def get_coord(c):
        return calc.wfs.gd.coords(c, pad=False) * Bohr

    w_wG = los.get_orbitals(indices)

    radii = covalent_radii[atoms.numbers]
    colors = jmol_colors[atoms.numbers]
    pos = atoms.positions

    # Take planes at atomic positions.
    a_list, _ = get_atoms(calc, indices)
    norm_coords = get_coord(norm_dir)
    slice_planes = np.searchsorted(norm_coords, pos[a_list, norm_dir])
    # searchsorted returns len(norm_coords) for positions beyond the last
    # grid point; clip so the plane index stays valid.
    slice_planes = np.clip(slice_planes, 0, len(norm_coords) - 1)

    # Take box limited by external atoms plus 2 Ang vacuum on each side.
    box_lims = [(lo - 2, hi + 2) for lo, hi in zip(pos[:, plane_dirs].min(0),
                                                   pos[:, plane_dirs].max(0))]
    box_widths = [hi - lo for lo, hi in box_lims]
    ratio = box_widths[0] / box_widths[1]  # Cell ratio
    num_orbs = len(indices)
    max_cols = 6
    nrows = (num_orbs - 1) // max_cols + 1
    ncols = min(num_orbs, max_cols)

    figsize = 5
    # squeeze=False always yields a 2-D array of axes, also for a single
    # orbital, where a bare Axes object would otherwise be returned.
    fig, axs = plt.subplots(nrows, ncols, squeeze=False, figsize=(
        ncols / nrows * ratio * figsize, figsize))

    X, Y = np.meshgrid(get_coord(plane_dirs[0]), get_coord(plane_dirs[1]),
                       indexing='ij')
    take_plane = [slice(None)] * 3

    # axs.flat visits panels in row-major order, so panel w shows orbital w.
    for w, ax in enumerate(axs.flat):
        if w >= num_orbs:
            ax.axis('off')
            continue

        take_plane[norm_dir] = slice_planes[w]
        C = w_wG[w][tuple(take_plane)]

        ax.pcolormesh(X, Y, C, cmap='jet', shading='gouraud')
        ax.set_xlim(box_lims[0])
        ax.set_ylim(box_lims[1])
        ax.axis('off')

        ax.scatter(pos[:, plane_dirs[0]], pos[:, plane_dirs[1]],
                   c=colors, s=radii * 1e3 / (nrows * 2))
    return fig