Coverage for gpaw/utilities/scalapack.py: 72%
268 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
# Copyright (C) 2003 CAMP
# Copyright (C) 2010 Argonne National Laboratory
# Please see the accompanying LICENSE file for further information.
"""
Python wrapper functions for the C and Fortran packages:
Basic Linear Algebra Communication Subprograms (BLACS)
Parallel Basic Linear Algebra Subprograms (PBLAS)
ScaLAPACK

See also:
https://www.netlib.org/blacs
and
https://www.netlib.org/scalapack
"""
15import numpy as np
17import gpaw.cgpaw as cgpaw
# NumPy arrays handled here are row-major (C order), whereas BLACS, PBLAS
# and ScaLAPACK work on column-major (Fortran order) storage.  A matrix
# handed down to Fortran is therefore seen transposed, which exchanges its
# upper/lower triangles and the left/right side of a product.  These tables
# translate a caller's flag into the flag the Fortran routine needs.
switch_lu = dict(U='L', L='U')
switch_lr = dict(L='R', R='L')
def scalapack_tri2full(desc, array, conj=True):
    """Write lower triangular part into upper triangular part of matrix.

    The strictly upper triangle of `array` is cleared and then filled with
    the transpose of the strictly lower triangle.  When `conj` is true the
    transpose is conjugated (A.H), otherwise it is a plain transpose (A.T).
    The diagonal is left as it was.

    This function is a frightful hack, but we can improve the
    implementation later.
    """
    # Clear the strictly upper triangle of the matrix itself.
    scalapack_zero(desc, array, 'U')
    # Work copy reduced to its strictly lower triangle: its upper triangle
    # and its diagonal are both set to zero.
    lower_only = array.copy()
    scalapack_set(desc, lower_only, alpha=0.0, beta=0.0, uplo='U')
    # array <- 1.0 * transpose(lower_only) + 1.0 * array, which mirrors the
    # lower triangle into the (now empty) upper triangle.
    pblas_tran(alpha=1.0, a_MN=lower_only,
               beta=1.0, c_NM=array,
               desca=desc, descc=desc,
               conj=conj)
def scalapack_zero(desca, a, uplo, ia=1, ja=1):
    """Zero the upper or lower half of a square matrix.

    Works by calling :func:`scalapack_set` with alpha = beta = 0 on the
    (N-1) x (N-1) submatrix that starts one row lower (uplo='L') or one
    column further right (uplo='U') than (ia, ja).  Because of that shift,
    the main diagonal of the full matrix is not touched.  `ia`/`ja` are the
    starting row/column (1-based, ScaLAPACK style, judging by the default).
    """
    assert desca.gshape[0] == desca.gshape[1]
    size = desca.gshape[0] - 1
    if uplo == 'L':
        row, col = ia + 1, ja
    else:
        row, col = ia, ja + 1
    scalapack_set(desca, a, 0.0, 0.0, uplo, size, size, row, col)
def scalapack_set(desca, a, alpha, beta, uplo, m=None, n=None, ia=1, ja=1):
    """Set the diagonal and upper/lower triangular part of a.

    The upper (uplo='U') or lower (uplo='L') triangular part of the m x n
    submatrix starting at (ia, ja) is set to alpha and its diagonal to
    beta.  alpha and beta may be real or complex numbers.  m and n default
    to the global shape of the matrix.  Processes outside the BLACS grid
    return without doing anything.
    """
    desca.checkassert(a)
    assert uplo in ['L', 'U']
    rows = desca.gshape[0] if m is None else m
    cols = desca.gshape[1] if n is None else n
    if desca.blacsgrid.is_active():
        # Fortran sees the transposed matrix: swap the triangle flag, the
        # dimensions and the row/column offsets.
        cgpaw.scalapack_set(a, desca.asarray(), alpha, beta,
                            switch_lu[uplo], cols, rows, ja, ia)
def scalapack_diagonalize_dc(desca, a, z, w, uplo):
    """Diagonalize symmetric matrix using the divide & conquer algorithm.

    Orthogonal eigenvectors are not guaranteed; no warning is provided.

    Solve the eigenvalue equation::

      A_nn Z_nn = w_N Z_nn

    Eigenvectors are written to Z_nn, and eigenvalues to the global array
    w_N in ascending order.  A_nn and Z_nn must both be compatible with the
    `desca` descriptor.  The contents of A_nn are overwritten.

    `uplo` ('L' or 'U') selects the triangle of the row-major input that
    is used.  It is passed through `switch_lu` because ScaLAPACK sees the
    array in column-major (transposed) order.

    Processes outside the BLACS grid return immediately.  Raises
    RuntimeError if the underlying routine reports a non-zero info.
    """
    for matrix in (a, z):
        desca.checkassert(matrix)
    size = desca.gshape[0]
    assert size == desca.gshape[1]  # square (symmetric) matrices only
    assert uplo in ['L', 'U']
    if not desca.blacsgrid.is_active():
        return
    assert size == len(w)
    info = cgpaw.scalapack_diagonalize_dc(a, desca.asarray(),
                                          switch_lu[uplo], z, w)
    if info != 0:
        raise RuntimeError('scalapack_diagonalize_dc error: %d' % info)
def scalapack_diagonalize_ex(desca, a, z, w, uplo, iu=None):
    """Diagonalize symmetric matrix using the bisection and inverse
    iteration algorithm.

    Re-orthogonalization of eigenvectors is an issue for tightly clustered
    eigenvalue problems: it requires substantial memory and is not
    scalable.  See the ScaLAPACK pdsyevx.f routine for more information.

    Solve the eigenvalue equation::

      A_nn Z_nn = w_N Z_nn

    Eigenvectors are written to Z_nn, and eigenvalues to the global array
    w_N in ascending order.  A_nn and Z_nn must both be compatible with the
    `desca` descriptor.  The contents of A_nn are overwritten.

    `uplo` ('L' or 'U') selects the triangle of the row-major input that
    is used.  It is passed through `switch_lu` because ScaLAPACK sees the
    array in column-major (transposed) order.

    `iu` specifies how many eigenvectors and eigenvalues to compute.  It
    must satisfy 1 <= iu <= N and defaults to N (all of them).

    Processes outside the BLACS grid return immediately.  Raises
    RuntimeError if the underlying routine reports a non-zero info.
    """
    desca.checkassert(a)
    desca.checkassert(z)
    size = desca.gshape[0]
    assert size == desca.gshape[1]  # square (symmetric) matrices only
    if iu is None:  # calculate all eigenvectors and eigenvalues
        iu = size
    # A single eigenpair (iu == 1) is a valid request; ScaLAPACK index
    # ranges only need 1 <= IU <= N.  The former `1 < iu` also rejected
    # every 1x1 matrix whenever iu was left at its default.
    assert 1 <= iu <= size
    assert uplo in ['L', 'U']
    if not desca.blacsgrid.is_active():
        return
    # The length of w is only checked on processes inside the grid.
    assert size == len(w)
    info = cgpaw.scalapack_diagonalize_ex(a, desca.asarray(),
                                          switch_lu[uplo],
                                          iu, z, w)
    if info != 0:
        # 0 means you are OK
        raise RuntimeError('scalapack_diagonalize_ex error: %d' % info)
def scalapack_diagonalize_mr3(desca, a, z, w, uplo, iu=None):
    """Diagonalize symmetric matrix using the MRRR algorithm.

    Solve the eigenvalue equation::

      A_nn Z_nn = w_N Z_nn

    Eigenvectors are written to Z_nn, and eigenvalues to the global array
    w_N in ascending order.  A_nn and Z_nn must both be compatible with the
    `desca` descriptor.  The contents of A_nn are overwritten.

    `uplo` ('L' or 'U') selects the triangle of the row-major input that
    is used.  It is passed through `switch_lu` because ScaLAPACK sees the
    array in column-major (transposed) order.

    `iu` specifies how many eigenvectors and eigenvalues to compute.  It
    must satisfy 1 <= iu <= N and defaults to N (all of them).

    Processes outside the BLACS grid return immediately.  Raises
    RuntimeError if the underlying routine reports a non-zero info.
    """
    desca.checkassert(a)
    desca.checkassert(z)
    size = desca.gshape[0]
    assert size == desca.gshape[1]  # square (symmetric) matrices only
    if iu is None:  # calculate all eigenvectors and eigenvalues
        iu = size
    # A single eigenpair (iu == 1) is a valid request; ScaLAPACK index
    # ranges only need 1 <= IU <= N.  The former `1 < iu` also rejected
    # every 1x1 matrix whenever iu was left at its default.
    assert 1 <= iu <= size
    assert uplo in ['L', 'U']
    if not desca.blacsgrid.is_active():
        return
    # The length of w is only checked on processes inside the grid.
    assert size == len(w)
    info = cgpaw.scalapack_diagonalize_mr3(a, desca.asarray(),
                                           switch_lu[uplo],
                                           iu, z, w)
    if info != 0:
        raise RuntimeError('scalapack_diagonalize_mr3 error: %d' % info)
def scalapack_general_diagonalize_dc(desca, a, b, z, w, uplo):
    """Diagonalize symmetric matrix using the divide & conquer algorithm.

    Orthogonal eigenvectors are not guaranteed; no warning is provided.

    Solve the generalized eigenvalue equation::

      A_nn Z_nn = w_N B_nn Z_nn

    B_nn is assumed to be positive definite.  Eigenvectors are written to
    Z_nn, and eigenvalues to the global array w_N in ascending order.
    A_nn, B_nn and Z_nn must all be compatible with the `desca`
    descriptor.  The contents of A_nn and B_nn are overwritten.

    `uplo` ('L' or 'U') selects the triangle of the row-major inputs that
    is used.  It is passed through `switch_lu` because ScaLAPACK sees the
    arrays in column-major (transposed) order.

    Processes outside the BLACS grid return immediately.  Raises
    RuntimeError if the underlying routine reports a non-zero info.
    """
    for matrix in (a, b, z):
        desca.checkassert(matrix)
    size = desca.gshape[0]
    assert size == desca.gshape[1]  # square (symmetric) matrices only
    assert uplo in ['L', 'U']
    if not desca.blacsgrid.is_active():
        return
    assert size == len(w)
    info = cgpaw.scalapack_general_diagonalize_dc(a, desca.asarray(),
                                                  switch_lu[uplo], b, z, w)
    if info != 0:
        raise RuntimeError('scalapack_general_diagonalize_dc error: %d' % info)
def scalapack_general_diagonalize_ex(desca, a, b, z, w, uplo, iu=None):
    """Diagonalize symmetric matrix using the bisection and inverse
    iteration algorithm.

    Re-orthogonalization of eigenvectors is an issue for tightly clustered
    eigenvalue problems: it requires substantial memory and is not
    scalable.  See the ScaLAPACK pdsyevx.f routine for more information.

    Solves the generalized eigenvalue equation::

      A_nn Z_nn = w_N B_nn Z_nn

    B_nn is assumed to be positive definite.  Eigenvectors are written to
    Z_nn, and eigenvalues to the global array w_N in ascending order.
    A_nn, B_nn and Z_nn must all be compatible with the `desca`
    descriptor.  The contents of A_nn and B_nn are overwritten.

    `uplo` ('L' or 'U') selects the triangle of the row-major inputs that
    is used.  It is passed through `switch_lu` because ScaLAPACK sees the
    arrays in column-major (transposed) order.

    `iu` specifies how many eigenvectors and eigenvalues to compute.  It
    must satisfy 1 <= iu <= N and defaults to N (all of them).

    Processes outside the BLACS grid return immediately.  Raises
    RuntimeError if the underlying routine reports a non-zero info.
    """
    desca.checkassert(a)
    desca.checkassert(b)
    desca.checkassert(z)
    size = desca.gshape[0]
    assert size == desca.gshape[1]  # square (symmetric) matrices only
    if iu is None:  # calculate all eigenvectors and eigenvalues
        iu = size
    # A single eigenpair (iu == 1) is a valid request; ScaLAPACK index
    # ranges only need 1 <= IU <= N.  The former `1 < iu` also rejected
    # every 1x1 matrix whenever iu was left at its default.
    assert 1 <= iu <= size
    assert uplo in ['L', 'U']
    if not desca.blacsgrid.is_active():
        return
    # The length of w is only checked on processes inside the grid.
    assert size == len(w)
    info = cgpaw.scalapack_general_diagonalize_ex(a, desca.asarray(),
                                                  switch_lu[uplo],
                                                  iu, b, z, w)
    if info != 0:
        # 0 means you are OK
        raise RuntimeError('scalapack_general_diagonalize_ex error: %d' % info)
def scalapack_general_diagonalize_mr3(desca, a, b, z, w, uplo, iu=None):
    """Diagonalize symmetric matrix using the MRRR algorithm.

    Solve the generalized eigenvalue equation::

      A_nn Z_nn = w_N B_nn Z_nn

    B_nn is assumed to be positive definite.  Eigenvectors are written to
    Z_nn, and eigenvalues to the global array w_N in ascending order.
    A_nn, B_nn and Z_nn must all be compatible with the `desca`
    descriptor.  The contents of A_nn and B_nn are overwritten.

    `uplo` ('L' or 'U') selects the triangle of the row-major inputs that
    is used.  It is passed through `switch_lu` because ScaLAPACK sees the
    arrays in column-major (transposed) order.

    `iu` specifies how many eigenvectors and eigenvalues to compute.  It
    must satisfy 1 <= iu <= N and defaults to N (all of them).

    Processes outside the BLACS grid return immediately.  Raises
    RuntimeError if the underlying routine reports a non-zero info.
    """
    desca.checkassert(a)
    desca.checkassert(b)
    desca.checkassert(z)
    size = desca.gshape[0]
    assert size == desca.gshape[1]  # square (symmetric) matrices only
    if iu is None:  # calculate all eigenvectors and eigenvalues
        iu = size
    # A single eigenpair (iu == 1) is a valid request; ScaLAPACK index
    # ranges only need 1 <= IU <= N.  The former `1 < iu` also rejected
    # every 1x1 matrix whenever iu was left at its default.
    assert 1 <= iu <= size
    assert uplo in ['L', 'U']
    if not desca.blacsgrid.is_active():
        return
    # The length of w is only checked on processes inside the grid.
    assert size == len(w)
    info = cgpaw.scalapack_general_diagonalize_mr3(a, desca.asarray(),
                                                   switch_lu[uplo],
                                                   iu, b, z, w)
    if info != 0:
        raise RuntimeError('scalapack_general_diagonalize_mr3 error: %d' %
                           info)
def have_mkl():
    """Return True if the compiled cgpaw extension provides the MKL
    ScaLAPACK non-symmetric eigensolver wrapper."""
    geev_wrapper = 'mklscalapack_diagonalize_geev'
    return hasattr(cgpaw, geev_wrapper)
def mkl_scalapack_diagonalize_non_symmetric(desca, a, z, w, transpose=True):
    """Diagonalize a non-symmetric matrix.

    Requires MKL ScaLAPACK to function (see :func:`have_mkl`).  The
    result arrays are passed to the MKL geev wrapper as (a, z, w);
    presumably z receives eigenvectors and w eigenvalues — confirm against
    the C wrapper.

    `transpose` is true by default, in order to match Fortran array
    ordering.  The input is transposed into a complex work array before the
    call, and z is transposed back in place afterwards.  Disable this if
    you want more control and reduced overhead; `a` is then handed to the
    solver directly.

    The block size must be at least 6 in both dimensions.  Processes
    outside the BLACS grid return immediately.  Raises RuntimeError if the
    solver reports a non-zero info.
    """
    desca.checkassert(a)
    desca.checkassert(z)

    assert desca.gshape[0] == desca.gshape[1]
    assert all(bsize >= 6 for bsize in desca.bshape), \
        'Block size must be >= 6'

    if not desca.blacsgrid.is_active():
        return

    if transpose:
        a2 = desca.empty(dtype=complex)
        pblas_tran(1, a, 0, a2, desca, desca, conj=False)
    else:
        # Without this branch `a2` was unbound and transpose=False always
        # failed with NameError.
        a2 = a
    info = cgpaw.mklscalapack_diagonalize_geev(a2, z, w, desca.asarray())
    if transpose:
        z2 = desca.empty(dtype=complex)
        pblas_tran(1, z, 0, z2, desca, desca, conj=False)
        z[:] = z2

    if info != 0:
        raise RuntimeError('mkl_non_symmetric_diagonalize_geevx error: %d'
                           % info)
def scalapack_inverse_cholesky(desca, a, uplo):
    """Perform a Cholesky decomposition followed by an inversion
    of the resulting triangular matrix.

    Only the upper or lower half of the matrix a will be
    modified; the other half is zeroed out.

    `uplo` ('L' or 'U') selects the triangle of the row-major input that
    is used.  It is passed through `switch_lu` because ScaLAPACK sees the
    array in column-major (transposed) order.

    Processes outside the BLACS grid return immediately.  Raises
    RuntimeError if the underlying routine reports a non-zero info.
    """
    desca.checkassert(a)
    rows, cols = desca.gshape[0], desca.gshape[1]
    assert rows == cols  # square (symmetric) matrices only
    assert uplo in ['L', 'U']
    if desca.blacsgrid.is_active():
        info = cgpaw.scalapack_inverse_cholesky(a, desca.asarray(),
                                                switch_lu[uplo])
        if info != 0:
            raise RuntimeError('scalapack_inverse_cholesky error: %d' % info)
def scalapack_inverse(desca, a, uplo):
    """Perform a hermitian matrix inversion in place.

    `uplo` ('L' or 'U') selects the triangle of the row-major input that
    is used.  It is passed through `switch_lu` because ScaLAPACK sees the
    array in column-major (transposed) order.

    Processes outside the BLACS grid return immediately.  Raises
    RuntimeError if the underlying routine reports a non-zero info.
    """
    desca.checkassert(a)
    rows, cols = desca.gshape[0], desca.gshape[1]
    assert rows == cols  # square (symmetric) matrices only
    assert uplo in ['L', 'U']
    if desca.blacsgrid.is_active():
        info = cgpaw.scalapack_inverse(a, desca.asarray(), switch_lu[uplo])
        if info != 0:
            raise RuntimeError('scalapack_inverse error: %d' % info)
def scalapack_solve(desca, descb, a, b):
    """General matrix solve.

    Solve X from A*X = B. The array b will be replaced with the result.

    This function works on the transposed form. The equivalent
    non-distributed operation is numpy.linalg.solve(a.T, b.T).T.

    This function executes the following scalapack routine:
    * pzgesv if matrices are complex
    * pdgesv if matrices are real

    Processes outside the BLACS grid of `desca` return immediately.
    Raises RuntimeError if the routine reports a non-zero info.
    """
    desca.checkassert(a)
    descb.checkassert(b)
    a_rows, a_cols = desca.gshape[0], desca.gshape[1]
    a_brows, a_bcols = desca.bshape[0], desca.bshape[1]
    assert a_rows == a_cols, 'A not a square matrix'
    assert a_brows == a_bcols, 'A not having square blocks'
    # Transposed form: B shares its column count/blocking with A.
    assert a_cols == descb.gshape[1], 'B shape not compatible with A'
    assert a_bcols == descb.bshape[1], 'B blocks not compatible with A'

    if desca.blacsgrid.is_active():
        info = cgpaw.scalapack_solve(a, desca.asarray(), b, descb.asarray())
        if info != 0:
            raise RuntimeError('scalapack_solve error: %d' % info)
def pblas_tran(alpha, a_MN, beta, c_NM, desca, descc, conj=True):
    """Matrix transpose.

    C <- alpha*A.H + beta*C if conj == True
    C <- alpha*A.T + beta*C if conj == False

    A is M x N (described by `desca`) and C must be N x M (described by
    `descc`).

    This function executes the following PBLAS routine:
    * pztranc if matrices are complex and conj == True
    * pztranu if matrices are complex and conj == False
    * pdtran if matrices are real
    """
    desca.checkassert(a_MN)
    descc.checkassert(c_NM)
    M, N = desca.gshape
    # The old `assert N, M == descc.gshape` asserted only that N was
    # non-zero (the comparison was just the message).  Validate the shape
    # of C for real; tuple() accepts any 2-sequence gshape.
    assert tuple(descc.gshape) == (N, M), 'C shape must be transpose of A'
    cgpaw.pblas_tran(N, M, alpha, a_MN, beta, c_NM,
                     desca.asarray(), descc.asarray(),
                     conj)
414def _pblas_hemm_symm(alpha, a_MM, b_MN, beta, c_MN, desca, descb, descc,
415 side, uplo, hemm):
416 """Hermitian or symmetric matrix-matrix product.
418 Do not call this function directly but
419 use :func:`pblas_hemm` or :func:`pblas_symm` instead.
421 C <- alpha*A*B + beta*C if side == 'L'
422 C <- alpha*B*A + beta*C if side == 'R'
424 Only lower or upper diagonal of a_MM is used.
426 This function executes the following PBLAS routine:
427 * pzhemm if matrices are complex and hemm == True
428 * pzsymm if matrices are complex and hemm == False
429 * pdsymm if matrices are real
430 """
431 # Note: if side == 'R', then a_MM matrix is actually size of a_NN
432 desca.checkassert(a_MM)
433 descb.checkassert(b_MN)
434 descc.checkassert(c_MN)
435 assert side in ['L', 'R'] and uplo in ['L', 'U']
436 Ma, Ma2 = desca.gshape
437 assert Ma == Ma2, 'A not square matrix'
438 Mb, Nb = descb.gshape
439 if side == 'L':
440 assert Mb == Ma
441 else:
442 assert Nb == Ma
443 M, N = descc.gshape
444 assert M == Mb
445 assert N == Nb
447 if not desca.blacsgrid.is_active():
448 return
449 cgpaw.pblas_hemm_symm(switch_lr[side], switch_lu[uplo],
450 N, M, alpha, a_MM, b_MN, beta, c_MN,
451 desca.asarray(), descb.asarray(), descc.asarray(),
452 hemm)
def pblas_hemm(alpha, a_MM, b_MN, beta, c_MN, desca, descb, descc,
               side='L', uplo='L'):
    """Hermitian matrix-matrix product.

    C <- alpha*A*B + beta*C if side == 'L'
    C <- alpha*B*A + beta*C if side == 'R'

    Only the lower or upper triangle of a_MM (chosen by `uplo`) is used.

    This function executes the following PBLAS routine:
    * pzhemm if matrices are complex
    * pdsymm if matrices are real
    """
    return _pblas_hemm_symm(alpha, a_MM, b_MN, beta, c_MN,
                            desca=desca, descb=descb, descc=descc,
                            side=side, uplo=uplo, hemm=True)
def pblas_symm(alpha, a_MM, b_MN, beta, c_MN, desca, descb, descc,
               side='L', uplo='L'):
    """Symmetric matrix-matrix product.

    C <- alpha*A*B + beta*C if side == 'L'
    C <- alpha*B*A + beta*C if side == 'R'

    Only the lower or upper triangle of a_MM (chosen by `uplo`) is used.

    This function executes the following PBLAS routine:
    * pzsymm if matrices are complex
    * pdsymm if matrices are real
    """
    return _pblas_hemm_symm(alpha, a_MM, b_MN, beta, c_MN,
                            desca=desca, descb=descb, descc=descc,
                            side=side, uplo=uplo, hemm=False)
def pblas_gemm(alpha, a_MK, b_KN, beta, c_MN, desca, descb, descc,
               transa='N', transb='N'):
    """General matrix-matrix product.

    C <- alpha*op(A)*op(B) + beta*C

    where op(X) is X, X.T or X.H for a trans flag of 'N', 'T' or 'C'.

    This function executes the following PBLAS routine:
    * pzgemm if matrices are complex
    * pdgemm if matrices are real
    """
    desca.checkassert(a_MK)
    descb.checkassert(b_KN)
    descc.checkassert(c_MN)
    valid_flags = ['N', 'T', 'C']
    assert transa in valid_flags and transb in valid_flags

    # Effective (post-op) shapes of A and B.
    M, Ka = desca.gshape
    if transa != 'N':
        M, Ka = Ka, M
    Kb, N = descb.gshape
    if transb != 'N':
        Kb, N = N, Kb
    Mc, Nc = descc.gshape

    assert Ka == Kb
    assert M == Mc
    assert N == Nc

    if desca.blacsgrid.is_active():
        # Column-major view computes C^T = op(B)^T op(A)^T, hence B is
        # passed first and the dimensions are given as (N, M, K).
        cgpaw.pblas_gemm(N, M, Ka, alpha, b_KN, a_MK, beta, c_MN,
                         descb.asarray(), desca.asarray(), descc.asarray(),
                         transb, transa)
def pblas_simple_gemm(desca, descb, descc, a_MK, b_KN, c_MN,
                      transa='N', transb='N'):
    """C <- op(A)*op(B): :func:`pblas_gemm` with alpha=1 and beta=0."""
    pblas_gemm(1.0, a_MK, b_KN, 0.0, c_MN, desca, descb, descc,
               transa, transb)
def pblas_simple_hemm(desca, descb, descc, a_MM, b_MN, c_MN,
                      side='L', uplo='L'):
    """C <- A*B (side='L') or B*A (side='R') for Hermitian A.

    Thin wrapper around :func:`pblas_hemm` with alpha=1 and beta=0.
    """
    pblas_hemm(1.0, a_MM, b_MN, 0.0, c_MN, desca, descb, descc, side, uplo)
def pblas_simple_symm(desca, descb, descc, a_MM, b_MN, c_MN,
                      side='L', uplo='L'):
    """C <- A*B (side='L') or B*A (side='R') for symmetric A.

    Thin wrapper around :func:`pblas_symm` with alpha=1 and beta=0.
    """
    pblas_symm(1.0, a_MM, b_MN, 0.0, c_MN, desca, descb, descc, side, uplo)
def pblas_gemv(alpha, a_MN, x_N, beta, y_M, desca, descx, descy,
               transa='N'):
    """General matrix-vector product.

    y <- alpha*op(A)*x + beta*y

    where op(A) is A, A.T or A.H for transa 'N', 'T' or 'C'.  x and y are
    distributed as single-column matrices.

    This function executes the following PBLAS routine:
    * pzgemv if matrices are complex
    * pdgemv if matrices are real
    """
    desca.checkassert(a_MN)
    descx.checkassert(x_N)
    descy.checkassert(y_M)
    assert transa in ['N', 'T', 'C']
    M, N = desca.gshape
    Nx, Ox = descx.gshape
    My, Oy = descy.gshape
    assert Ox == 1
    assert Oy == 1
    x_len, y_len = (N, M) if transa == 'N' else (M, N)
    assert Nx == x_len
    assert My == y_len

    # Fortran sees A transposed.  'C' is handled by conjugating a local
    # copy and then asking for a plain (non-transposed) product.
    if transa == 'C':
        a_MN = np.ascontiguousarray(a_MN.conj())
    fortran_trans = 'T' if transa == 'N' else 'N'

    if not desca.blacsgrid.is_active():
        return
    cgpaw.pblas_gemv(N, M, alpha,
                     a_MN, x_N, beta, y_M,
                     desca.asarray(),
                     descx.asarray(),
                     descy.asarray(),
                     fortran_trans)
def pblas_simple_gemv(desca, descx, descy, a, x, y, transa='N'):
    """y <- op(A)*x: :func:`pblas_gemv` with alpha=1 and beta=0."""
    pblas_gemv(1.0, a, x, 0.0, y, desca, descx, descy, transa)
def pblas_r2k(alpha, a_NK, b_NK, beta, c_NN, desca, descb, descc,
              uplo='U'):
    """Rank-2k update of the square matrix C from A and B.

    Presumably dispatches to PBLAS p?syr2k / p?her2k — confirm against the
    C wrapper.  A and B must have identical global shapes.  Processes
    outside the BLACS grid of `desca` return before any checks are made.
    """
    if not desca.blacsgrid.is_active():
        return
    for desc, matrix in ((desca, a_NK), (descb, b_NK), (descc, c_NN)):
        desc.checkassert(matrix)
    order = descc.gshape[0]  # order of C
    assert order == descc.gshape[1]  # symmetric matrix
    assert desca.gshape == descb.gshape  # same shape
    assert uplo in ['L', 'U']
    # K must take into account implicit transpose due to C ordering
    inner = desca.gshape[1]  # number of columns of A and B
    # NOTE(review): unlike most wrappers here, uplo is passed on without
    # switch_lu, so it refers to the Fortran (column-major) view of C.
    cgpaw.pblas_r2k(order, inner, alpha, a_NK, b_NK, beta, c_NN,
                    desca.asarray(),
                    descb.asarray(),
                    descc.asarray(),
                    uplo)
def pblas_simple_r2k(desca, descb, descc, a, b, c, uplo='U'):
    """Rank-2k update of C with alpha=1 and beta=0 via :func:`pblas_r2k`."""
    pblas_r2k(1.0, a, b, 0.0, c, desca, descb, descc, uplo)
def pblas_rk(alpha, a_NK, beta, c_NN, desca, descc,
             uplo='U'):
    """Rank-k update of the square matrix C from A.

    Presumably dispatches to PBLAS p?syrk / p?herk — confirm against the C
    wrapper.  Processes outside the BLACS grid of `desca` return before any
    checks are made.
    """
    if not desca.blacsgrid.is_active():
        return
    desca.checkassert(a_NK)
    descc.checkassert(c_NN)
    order = descc.gshape[0]  # order of C
    assert order == descc.gshape[1]  # symmetric matrix
    assert uplo in ['L', 'U']
    # K must take into account implicit transpose due to C ordering
    inner = desca.gshape[1]  # number of columns of A
    # NOTE(review): uplo is passed on without switch_lu, so it refers to
    # the Fortran (column-major) view of C.
    cgpaw.pblas_rk(order, inner, alpha, a_NK, beta, c_NN,
                   desca.asarray(),
                   descc.asarray(),
                   uplo)
def pblas_simple_rk(desca, descc, a, c, uplo='U'):
    """Rank-k update of C with alpha=1 and beta=0 via :func:`pblas_rk`.

    `uplo` is forwarded unchanged to :func:`pblas_rk`.  It defaults to 'U',
    which preserves the previous behaviour, and it is now exposed for
    parity with :func:`pblas_simple_r2k`.
    """
    pblas_rk(1.0, a, 0.0, c, desca, descc, uplo)