Coverage for gpaw/new/constraints.py: 98%
40 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
3import numpy as np
4from ase.units import Ha
5from gpaw.new.extensions import Extension
6from gpaw.typing import Array1D, Array3D, Vector
class SpinDirectionConstraint(Extension):
    """Penalty term that steers the spin direction of selected atoms.

    For every constrained atom with unit direction ``u`` the penalty is::

        E = lambda * (|M|^2 - (u . M)^2)

    where ``M`` is the atomic spin magnetic moment built from the
    magnetization part of the atomic density matrix.  ``E`` is zero when
    ``M`` is parallel (or antiparallel) to ``u`` and grows with the
    component of ``M`` perpendicular to ``u``.
    """
    name = 'spin_direction_constraint'

    def __init__(self,
                 constraint: dict[int, Vector],
                 penalty: float = 0.8):
        """Spin-direction constraint.

        Parameters
        ==========
        constraint:
            Dictionary mapping atom numbers to directions.
            Example: ``{0: (0, 0, 1), 1: (1, 0, 0), ...}``.
            Directions need not be normalized; they are scaled to unit
            length here.
        penalty:
            Strength of penalty term in eV.

        Raises
        ======
        ValueError:
            If a direction is not a finite, non-zero 3-vector (it could
            not be normalized and would otherwise silently yield NaNs).
        """
        self.constraint = {}
        for a, u_v in constraint.items():
            u_v = np.asarray(u_v, dtype=float)
            norm = np.linalg.norm(u_v)
            if u_v.shape != (3,) or not np.isfinite(norm) or norm == 0.0:
                raise ValueError(
                    'Spin direction must be a finite, non-zero 3-vector; '
                    f'got {u_v.tolist()} for atom {a}')
            self.constraint[a] = u_v / norm
        # Stored internally in Hartree atomic units.
        self.penalty = penalty / Ha

    def todict(self):
        """Return the constructor arguments as plain Python types.

        The penalty is converted back to eV so that
        ``SpinDirectionConstraint(**obj.todict())`` reproduces ``obj``.
        """
        return dict(constraint={a: u_v.tolist()
                                for a, u_v in self.constraint.items()},
                    penalty=self.penalty * Ha)

    def update_non_local_hamiltonian(self,
                                     D_sii,
                                     setup,
                                     atom_index,
                                     dH_sii) -> float:
        """Add the penalty gradient to the atomic Hamiltonian in place.

        Components ``1:4`` of the spin index of ``D_sii`` (presumably the
        x, y, z magnetization parts -- TODO confirm) are passed to
        ``calculate`` and the resulting gradient is added to the same
        components of ``dH_sii``.

        Returns the energy reported by ``calculate``.  Because
        ``return_energy`` is not requested here, this is always ``0.0``.
        """
        eL, dHL_vii = self.calculate(D_sii[1:4].real, atom_index,
                                     setup.l_j, setup.N0_q)
        dH_sii[1:4] += dHL_vii
        return eL

    def calculate(self,
                  M_vii: Array3D,
                  a: int,
                  l_j: Array1D,
                  N0_q: Array1D,
                  return_energy: bool = False):
        """Penalty gradient (and optionally energy) for atom ``a``.

        Parameters
        ==========
        M_vii:
            Magnetization part of the atomic density matrix, shape
            ``(3, ni, ni)``; ``update_non_local_hamiltonian`` passes
            ``D_sii[1:4].real``.
        a:
            Atom index.  Unconstrained atoms give zero energy and a zero
            gradient.
        l_j:
            Angular momentum of each partial-wave channel ``j``.  Channel
            ``j`` occupies ``2 * l_j[j] + 1`` consecutive ``i`` indices.
        N0_q:
            Radial overlaps for channel pairs ``j1 <= j2``, packed
            row-major over the upper triangle (presumably partial-wave
            overlap integrals inside the augmentation sphere -- TODO
            confirm).
        return_energy:
            If true, also evaluate the penalty energy.

        Returns
        =======
        (energy, dHL_vii):
            ``energy`` is the penalty energy in Hartree when
            ``return_energy`` is true, otherwise ``0.0``.  ``dHL_vii``
            has the shape of ``M_vii`` and holds ``dE/dM_vii``.
        """
        dHL_vii = np.zeros_like(M_vii)

        if a not in self.constraint:
            return 0.0, dHL_vii
        u_v = self.constraint[a]

        # Index range of every channel j inside the i (= j, m) dimension.
        slices = []
        start = 0
        for l in l_j:
            stop = start + 2 * l + 1
            slices.append(slice(start, stop))
            start = stop

        smm_v = np.zeros(3)  # Spin magnetic moment
        nj = len(l_j)
        for j1, (l1, i1) in enumerate(zip(l_j, slices)):
            for j2, (l2, i2) in enumerate(zip(l_j, slices)):
                if l1 != l2:
                    continue
                # Packed upper-triangle index of the pair (j1, j2).
                ja, jb = min(j1, j2), max(j1, j2)
                N0 = N0_q[jb + ja * nj - ja * (ja + 1) // 2]
                # The angular overlap of equal-l channels is diagonal in m,
                # so only the m1 == m2 elements contribute to the moment.
                # This matches the np.eye(...) derivative below.
                smm_v += np.trace(M_vii[:, i1, i2], axis1=1, axis2=2) * N0
                dHL_vii[:, i1, i2] += np.eye(2 * l1 + 1) * N0

        # dE/dM = 2 * lambda * (M - u (u . M)): the component of the moment
        # perpendicular to the target direction, scaled by the penalty.
        gradient_v = 2 * self.penalty * (smm_v - u_v * (u_v @ smm_v))
        dHL_vii *= gradient_v[:, np.newaxis, np.newaxis]

        if not return_energy:
            return 0.0, dHL_vii
        return self.penalty * (smm_v @ smm_v - (u_v @ smm_v)**2), dHL_vii