Coverage for gpaw/gpu/cpupy/__init__.py: 89%
245 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
2from types import SimpleNamespace
4import numpy as np
6import gpaw.gpu.cpupy.cublas as cublas
7import gpaw.gpu.cpupy.fft as fft
8import gpaw.gpu.cpupy.linalg as linalg
9import gpaw.gpu.cpupy.random as random
# Version string reported by this CPU substitute for ``cupy``.
__version__ = 'fake'

# Names exported by ``from gpaw.gpu.cpupy import *``.
__all__ = ['linalg', 'cublas', 'fft', 'random', '__version__']

# Banner text warning that a GPU calculation is actually being run on the
# CPU through this ``cupy`` substitute.
FAKE_CUPY_WARNING = """
 ----------------------------------------------------------
| WARNING |
| -------------------------------------------------------- |
| GPU calculation requested, but calculations are run on |
| CPUs with the `cupy` substitute `gpaw.gpu.cpupy`. |
| This is most likely not the desired behavior, except for |
| testing purposes. Please check if you have inadvertently |
| set the environment variable `GPAW_CPUPY`, consult |
| `gpaw info` for `cupy` availability, and reconfigure and |
| recompile GPAW if necessary. |
 ----------------------------------------------------------
"""

# Module-level constant mirroring ``cupy.pi`` (taken from numpy).
pi = np.pi
def require(a, requirements=None):
    """Return ``a`` as an array satisfying numpy ``requirements`` flags.

    Thin wrapper around :func:`numpy.require` applied to the wrapped data
    of ``a``; the result is wrapped in a new :class:`ndarray`.
    """
    satisfied = np.require(a._data, requirements=requirements)
    return ndarray(satisfied)
def empty(*args, **kwargs) -> ndarray:
    """Allocate an uninitialized array; arguments go to :func:`numpy.empty`."""
    buffer = np.empty(*args, **kwargs)
    return ndarray(buffer)
def empty_like(a):
    """Return an uninitialized array with the shape and dtype of ``a``."""
    template = a._data
    return ndarray(np.empty_like(template))
def zeros(*args, **kwargs):
    """Return a zero-filled array; arguments go to :func:`numpy.zeros`."""
    filled = np.zeros(*args, **kwargs)
    return ndarray(filled)
def ones(*args, **kwargs):
    """Return an array of ones; arguments go to :func:`numpy.ones`."""
    filled = np.ones(*args, **kwargs)
    return ndarray(filled)
def asnumpy(a, out=None):
    """Copy the data of ``a`` into a host :class:`numpy.ndarray`.

    When ``out`` is given, the data is written into it in place and ``out``
    itself is returned; otherwise a fresh copy is returned.
    """
    if out is not None:
        out[:] = a._data
        return out
    return a._data.copy()
def asarray(a, dtype=None):
    """Convert ``a`` to an :class:`ndarray`, avoiding a copy when possible.

    An :class:`ndarray` is returned unchanged unless a different ``dtype``
    is requested, in which case a converted copy is returned.  Any other
    input is converted with :func:`numpy.array`.
    """
    if isinstance(a, ndarray):
        keep_as_is = a.dtype == dtype or dtype is None
        return a if keep_as_is else ndarray(a._data.astype(dtype))
    return ndarray(np.array(a, dtype=dtype))
def array(a, dtype=None):
    """Create a new :class:`ndarray` holding a copy of ``a``.

    ``a`` may be an :class:`ndarray` (its wrapped data is copied, as with
    ``cupy.array``) or anything accepted by :func:`numpy.array`.
    """
    if isinstance(a, ndarray):
        # numpy does not recognize this wrapper as an array, so hand it the
        # wrapped numpy data instead of the wrapper object itself.
        a = a._data
    return ndarray(np.array(a, dtype))
def array_split(a, *args, **kwargs):
    """Split ``a`` into sub-arrays with :func:`numpy.array_split`.

    Returns a list of :class:`ndarray` wrappers around the numpy pieces.
    """
    pieces = np.array_split(a._data, *args, **kwargs)
    return [ndarray(piece) for piece in pieces]
def ascontiguousarray(a):
    """Return the data of ``a`` in C-contiguous layout, wrapped."""
    contiguous = np.ascontiguousarray(a._data)
    return ndarray(contiguous)
def dot(a, b):
    """Dot product of two :class:`ndarray` operands (see :func:`numpy.dot`)."""
    lhs, rhs = a._data, b._data
    return ndarray(np.dot(lhs, rhs))
def inner(a, b):
    """Inner product of two :class:`ndarray` operands (see :func:`numpy.inner`)."""
    lhs, rhs = a._data, b._data
    return ndarray(np.inner(lhs, rhs))
def outer(a, b):
    """Outer product of two :class:`ndarray` operands (see :func:`numpy.outer`)."""
    lhs, rhs = a._data, b._data
    return ndarray(np.outer(lhs, rhs))
def multiply(a, b, c):
    """Compute ``a * b`` elementwise, writing the result into ``c``.

    Returns ``None``; the product is only available through ``c``.
    """
    np.multiply(a._data, b._data, out=c._data)
def negative(a, b):
    """Write ``-a`` elementwise into ``b``; returns ``None``."""
    np.negative(a._data, out=b._data)
def einsum(indices, *args, **kwargs):
    """Evaluate an Einstein summation over :class:`ndarray` operands.

    Parameters
    ----------
    indices:
        Subscript string passed to :func:`numpy.einsum`.
    *args:
        :class:`ndarray` operands.
    **kwargs:
        Forwarded to :func:`numpy.einsum`.  :class:`ndarray` values (such
        as ``out``) are unwrapped; any other value (for example
        ``optimize=True``) is passed through unchanged instead of failing
        on a missing ``_data`` attribute.
    """
    operands = [arg._data for arg in args]
    options = {key: value._data if isinstance(value, ndarray) else value
               for key, value in kwargs.items()}
    return ndarray(np.einsum(indices, *operands, **options))
def diag(a):
    """Extract a diagonal or build a diagonal matrix (see :func:`numpy.diag`)."""
    result = np.diag(a._data)
    return ndarray(result)
def abs(a):
    """Elementwise absolute value (mirrors ``cupy.abs``; shadows the builtin here)."""
    magnitude = np.abs(a._data)
    return ndarray(magnitude)
def exp(a):
    """Elementwise exponential of ``a``."""
    result = np.exp(a._data)
    return ndarray(result)
def conjugate(a):
    """Elementwise complex conjugate of ``a``."""
    result = np.conjugate(a._data)
    return ndarray(result)
def log(a):
    """Elementwise natural logarithm of ``a``."""
    result = np.log(a._data)
    return ndarray(result)
def eye(n, m=None, k=0, dtype=float):
    """Return a 2-D array with ones on the ``k``-th diagonal.

    Parameters mirror :func:`numpy.eye` / ``cupy.eye``: ``n`` rows, ``m``
    columns (defaults to ``n``), diagonal offset ``k`` and ``dtype``.  The
    defaults reproduce the plain ``eye(n)`` call.
    """
    return ndarray(np.eye(n, m, k, dtype=dtype))
def triu_indices(n, k=0, m=None):
    """Row and column indices of the upper triangle (see :func:`numpy.triu_indices`).

    Returns a pair of :class:`ndarray` index arrays.
    """
    return tuple(map(ndarray, np.triu_indices(n, k, m)))
def tri(n, k=0, dtype=float):
    """Square matrix with ones at and below diagonal ``k`` (see :func:`numpy.tri`)."""
    lower = np.tri(n, k=k, dtype=dtype)
    return ndarray(lower)
def allclose(a, b, **kwargs):
    """Return whether ``a`` and ``b`` are elementwise equal within tolerance.

    Both inputs are first converted with :func:`asarray`; keyword
    arguments (``rtol``, ``atol``, ...) go to :func:`numpy.allclose`.
    The result is a plain Python/numpy boolean, not an :class:`ndarray`.
    """
    lhs = asarray(a)._data
    rhs = asarray(b)._data
    return np.allclose(lhs, rhs, **kwargs)
def moveaxis(a, source, destination):
    """Move axes of ``a`` from ``source`` to ``destination`` (see :func:`numpy.moveaxis`)."""
    moved = np.moveaxis(a._data, source, destination)
    return ndarray(moved)
def vdot(a, b):
    """Dot product of the flattened arrays, conjugating ``a`` first.

    Unlike most functions here, the numpy result is returned as-is rather
    than wrapped in an :class:`ndarray`.
    """
    lhs, rhs = a._data, b._data
    return np.vdot(lhs, rhs)
def fuse(*args, **kwargs):
    """Stand-in for ``cupy.fuse``: decorated functions are returned unchanged.

    Both decorator forms accepted by ``cupy.fuse`` work::

        @fuse()              # called, optionally with keyword arguments
        def f(x): ...

        @fuse                # bare
        def g(x): ...

    No kernel fusion happens on the CPU.
    """
    if len(args) == 1 and not kwargs and callable(args[0]):
        # Bare ``@fuse``: the function itself is the only argument.
        return args[0]
    return lambda func: func
def isfinite(a):
    """Elementwise test for finite values (not inf, not NaN)."""
    mask = np.isfinite(a._data)
    return ndarray(mask)
def isnan(a):
    """Elementwise test for NaN values."""
    mask = np.isnan(a._data)
    return ndarray(mask)
167class ndarray:
168 def __init__(self, data):
169 if isinstance(data, (float, complex, int, np.int32, np.int64,
170 np.bool_, np.float64, np.float32,
171 np.complex64, np.complex128)):
172 data = np.asarray(data)
173 assert isinstance(data, np.ndarray), type(data)
174 self._data = data
175 self.dtype = data.dtype
176 self.size = data.size
177 self.flags = data.flags
178 self.ndim = data.ndim
179 self.nbytes = data.nbytes
180 self.data = SimpleNamespace(ptr=data.ctypes.data)
182 @property
183 def shape(self):
184 return self._data.shape
186 @property
187 def T(self):
188 return ndarray(self._data.T)
190 @property
191 def real(self):
192 return ndarray(self._data.real)
194 @property
195 def imag(self):
196 return ndarray(self._data.imag)
198 @imag.setter
199 def imag(self, value):
200 if isinstance(value, (float, complex)):
201 self._data.imag = value
202 else:
203 self._data.imag = value._data
205 def set(self, a):
206 if self.ndim == 0:
207 self._data.fill(a)
208 else:
209 self._data[:] = a
211 def get(self):
212 return self._data.copy()
214 def copy(self):
215 return ndarray(self._data.copy())
217 def astype(self, dtype):
218 return ndarray(self._data.astype(dtype))
220 def all(self):
221 return ndarray(self._data.all())
223 def sum(self, out=None, **kwargs):
224 if out is not None:
225 out = out._data
226 return ndarray(self._data.sum(out=out, **kwargs))
228 def __repr__(self):
229 return 'cp.' + np.array_repr(self._data)
231 def __len__(self):
232 return len(self._data)
234 def __bool__(self):
235 return bool(self._data)
237 def __float__(self):
238 return self._data.__float__()
240 def __iter__(self):
241 for data in self._data:
242 if data.ndim == 0:
243 yield ndarray(data.item())
244 else:
245 yield ndarray(data)
247 def mean(self):
248 return ndarray(self._data.mean())
250 def __setitem__(self, index, value):
251 if isinstance(index, tuple):
252 def convert(a):
253 return a._data if isinstance(a, ndarray) else a
254 index = tuple([convert(a) for a in index])
255 if isinstance(index, ndarray):
256 index = index._data
257 if isinstance(value, ndarray):
258 self._data[index] = value._data
259 else:
260 assert isinstance(value, (float, int, complex))
261 self._data[index] = value
263 def __getitem__(self, index):
264 if isinstance(index, tuple):
265 def convert(a):
266 return a._data if isinstance(a, ndarray) else a
267 index = tuple([convert(a) for a in index])
268 if isinstance(index, ndarray):
269 index = index._data
270 return ndarray(self._data[index])
272 def __eq__(self, other):
273 if isinstance(other, (float, complex, int)):
274 return self._data == other
275 return ndarray(self._data == other._data)
277 def __ne__(self, other):
278 if isinstance(other, (float, complex, int)):
279 return self._data != other
280 return ndarray(self._data != other._data)
282 def __lt__(self, other):
283 if isinstance(other, (float, complex, int)):
284 return self._data < other
285 return ndarray(self._data < other._data)
287 def __le__(self, other):
288 if isinstance(other, (float, complex, int)):
289 return self._data <= other
290 return ndarray(self._data <= other._data)
292 def __gt__(self, other):
293 if isinstance(other, (float, complex, int)):
294 return self._data > other
295 return ndarray(self._data > other._data)
297 def __ge__(self, other):
298 if isinstance(other, (float, complex, int)):
299 return self._data >= other
300 return ndarray(self._data >= other._data)
302 def __neg__(self):
303 return ndarray(-self._data)
305 def __mul__(self, f):
306 if isinstance(f, (float, complex)):
307 return ndarray(f * self._data)
308 return ndarray(f._data * self._data)
310 def __rmul__(self, f):
311 return ndarray(f * self._data)
313 def __imul__(self, f):
314 if isinstance(f, (float, complex, int)):
315 self._data *= f
316 else:
317 self._data *= f._data
318 return self
320 def __truediv__(self, other):
321 if isinstance(other, (float, complex, int)):
322 return ndarray(self._data / other)
323 return ndarray(self._data / other._data)
325 def __pow__(self, i: int):
326 return ndarray(self._data**i)
328 def __add__(self, f):
329 if isinstance(f, (float, int, complex)):
330 return ndarray(f + self._data)
331 return ndarray(f._data + self._data)
333 def __sub__(self, f):
334 if isinstance(f, float):
335 return ndarray(self._data - f)
336 return ndarray(self._data - f._data)
338 def __rsub__(self, f):
339 return ndarray(f - self._data)
341 def __radd__(self, f):
342 return ndarray(f + self._data)
344 def __rtruediv__(self, f):
345 return ndarray(f / self._data)
347 def __iadd__(self, other):
348 if isinstance(other, float):
349 self._data += other
350 else:
351 self._data += other._data
352 return self
354 def __isub__(self, other):
355 if isinstance(other, float):
356 self._data -= other
357 else:
358 self._data -= other._data
359 return self
361 def __matmul__(self, other):
362 return ndarray(self._data @ other._data)
364 def ravel(self):
365 return ndarray(self._data.ravel())
367 def conj(self):
368 return ndarray(self._data.conj())
370 def reshape(self, shape):
371 return ndarray(self._data.reshape(shape))
373 def view(self, dtype):
374 return ndarray(self._data.view(dtype))
376 def item(self):
377 return self._data.item()
379 def trace(self, offset, axis1, axis2):
380 return ndarray(self._data.trace(offset, axis1, axis2))
382 def fill(self, val):
383 self._data.fill(val)
385 def any(self):
386 return ndarray(self._data.any())