Coverage for gpaw/new/constraints.py: 98%

40 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from ase.units import Ha 

5from gpaw.new.extensions import Extension 

6from gpaw.typing import Array1D, Array3D, Vector 

7 

8 

class SpinDirectionConstraint(Extension):
    """Penalty term that constrains the spin direction of selected atoms.

    For each constrained atom with unit direction ``u`` the energy
    ``penalty * (|m|**2 - (u . m)**2)`` is used, where ``m`` is the atomic
    spin magnetic moment accumulated from the atomic density matrix.  This
    equals ``penalty * |m_perp|**2``: it vanishes when ``m`` is parallel or
    anti-parallel to ``u`` and penalizes the perpendicular component.
    """

    name = 'spin_direction_constraint'

    def __init__(self,
                 constraint: dict[int, Vector],
                 penalty: float = 0.8):
        """Spin-direction constraint.

        Parameters
        ==========
        constraint:
            Dictionary mapping atom numbers to directions.
            Example: ``{0: (0, 0, 1), 1: (1, 0, 0), ...}``.
            Directions are normalized internally and must be non-zero.
        penalty:
            Strength of penalty term in eV.

        Raises
        ======
        ValueError:
            If a direction vector has zero length (it cannot be normalized).
        """
        self.constraint = {}
        for a, u_v in constraint.items():
            u_v = np.asarray(u_v, dtype=float)
            norm = np.linalg.norm(u_v)
            if norm == 0.0:
                # Dividing by zero would silently store NaN directions.
                raise ValueError(
                    f'Spin direction for atom {a} must be non-zero')
            self.constraint[a] = u_v / norm
        # Stored in Hartree; converted back to eV in todict().
        self.penalty = penalty / Ha

    def todict(self):
        """Return constructor arguments (normalized directions, eV)."""
        return {'constraint': {a: u_v.tolist()
                               for a, u_v in self.constraint.items()},
                'penalty': self.penalty * Ha}

    def update_non_local_hamiltonian(self,
                                     D_sii,
                                     setup,
                                     atom_index,
                                     dH_sii) -> float:
        """Add the penalty derivative to ``dH_sii[1:4]`` in place.

        Components 1:4 of ``D_sii`` (real part) are passed to ``calculate``
        as the x, y, z magnetization blocks.

        NOTE(review): ``calculate`` is called without ``return_energy``, so
        the returned energy is always 0.0 -- confirm that leaving the
        penalty energy out of the total energy is intended.
        """
        eL, dHL_vii = self.calculate(D_sii[1:4].real, atom_index,
                                     setup.l_j, setup.N0_q)
        dH_sii[1:4] += dHL_vii
        return eL

    def calculate(self,
                  M_vii: Array3D,
                  a: int,
                  l_j: Array1D,
                  N0_q: Array1D,
                  return_energy: bool = False) -> tuple[float, Array3D]:
        """Penalty energy and its derivative for atom ``a``.

        Parameters
        ==========
        M_vii:
            x, y, z magnetization blocks of the atomic density matrix,
            shape ``(3, ni, ni)`` with ``ni = sum(2 * l + 1 for l in l_j)``.
        a:
            Atom index.  Unconstrained atoms give a zero contribution.
        l_j:
            Angular momentum quantum number of each projector ``j``.
        N0_q:
            Coefficients packed over pairs ``(j1, j2)`` with ``j1 <= j2``
            (presumably the setup's radial overlap corrections -- confirm).
        return_energy:
            If False, 0.0 is returned in place of the energy.

        Returns
        =======
        ``(energy, dHL_vii)`` where ``dHL_vii`` is the derivative of the
        penalty energy with respect to ``M_vii``.
        """
        dHL_vii = np.zeros_like(M_vii)

        if a not in self.constraint:
            return 0.0, dHL_vii
        u_v = self.constraint[a]

        # Only blocks with l1 == l2 contribute, and within such a block
        # only the diagonal m1 == m2 elements (orthonormal angular parts).
        # The moment therefore uses the block trace, and its derivative
        # with respect to the block is the identity times N0.
        smm_v = np.zeros(3)  # Spin magnetic moment
        nj = len(l_j)
        i1 = 0
        for j1, l1 in enumerate(l_j):
            n1 = 2 * l1 + 1
            i2 = 0
            for j2, l2 in enumerate(l_j):
                n2 = 2 * l2 + 1
                if l1 == l2:
                    # Packed upper-triangle index of the (j1, j2) pair.
                    jlo, jhi = min(j1, j2), max(j1, j2)
                    N0 = N0_q[jhi + jlo * nj - jlo * (jlo + 1) // 2]
                    block_vmm = M_vii[:, i1:i1 + n1, i2:i2 + n2]
                    smm_v += np.trace(block_vmm, axis1=1, axis2=2) * N0
                    dHL_vii[:, i1:i1 + n1, i2:i2 + n2] += np.eye(n1) * N0
                i2 += n2
            i1 += n1

        # dE/dm = 2 * penalty * (m - u (u . m)), the gradient of
        # penalty * (|m|**2 - (u . m)**2) for a unit vector u.
        dEdm_v = 2 * self.penalty * (smm_v - u_v * (u_v @ smm_v))
        dHL_vii *= dEdm_v[:, np.newaxis, np.newaxis]

        if not return_energy:
            return 0.0, dHL_vii
        energy = self.penalty * (smm_v @ smm_v - (u_v @ smm_v)**2)
        return energy, dHL_vii