Source code for ase.utils

import errno
import functools
import os
import pickle
import sys
import time
import string
import warnings
from importlib import import_module
from math import sin, cos, radians, atan2, degrees
from contextlib import contextmanager
from math import gcd
from pathlib import PurePath, Path

import numpy as np

from ase.formula import formula_hill, formula_metal

# Names exported by ``from ase.utils import *``.  Every entry must be bound
# in this module: the former entries 'exec_' and 'FileNotFoundError' were
# not, which made the star-import itself fail with AttributeError.
__all__ = ['basestring', 'import_module', 'seterr', 'plural',
           'devnull', 'gcd', 'convert_string_to_fd', 'Lock',
           'opencew', 'OpenLock', 'rotate', 'irotate', 'pbc2pbc', 'givens',
           'hsv2rgb', 'hsv', 'pickleload',
           'formula_hill', 'formula_metal', 'PurePath']


# Python 2+3 compatibility stuff (let's try to remove these things):
basestring = str
# pickle.load preconfigured with encoding='bytes'.
pickleload = functools.partial(pickle.load, encoding='bytes')


def deprecated(msg, category=FutureWarning):
    """Return a decorator deprecating a function.

    Use like @deprecated('warning message and explanation').

    Parameters:

    msg: str or Warning
        Message to emit.  A Warning instance is emitted as-is; anything
        else is wrapped in ``category``.
    category: Warning subclass
        Warning class used when ``msg`` is not already a Warning.

    The decorated function emits the warning on every call and then
    returns whatever the wrapped function returns.
    """
    def deprecated_decorator(func):
        @functools.wraps(func)
        def deprecated_function(*args, **kwargs):
            warning = msg
            if not isinstance(warning, Warning):
                warning = category(warning)
            # stacklevel=2 attributes the warning to the code calling the
            # deprecated function rather than to this wrapper, so the
            # reported file/line is the one that needs to be updated.
            warnings.warn(warning, stacklevel=2)
            return func(*args, **kwargs)
        return deprecated_function
    return deprecated_decorator


@contextmanager
def seterr(**kwargs):
    """Temporarily change numpy's floating-point error handling.

    The keyword arguments are forwarded to np.seterr(); the settings
    that were active before are reinstated when the block is left,
    also if it is left through an exception.
    """
    previous = np.seterr(**kwargs)
    try:
        yield
    finally:
        # np.seterr() returned the former settings as a dict.
        np.seterr(**previous)
def plural(n, word):
    """Return count and word, appending 's' to the word unless n == 1.

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)
class DevNull:
    """File-like object that discards writes and reads back nothing.

    Every method is wrapped so that calling it emits a
    DeprecationWarning pointing users to open(os.devnull).
    """
    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0.  Change to futurewarning later on.
    _use_os_devnull = deprecated('use open(os.devnull) instead',
                                 DeprecationWarning)

    # --- output side -------------------------------------------------
    @_use_os_devnull
    def write(self, string):
        pass

    @_use_os_devnull
    def flush(self):
        pass

    # --- input side --------------------------------------------------
    @_use_os_devnull
    def read(self, n=-1):
        return ''

    # --- positioning -------------------------------------------------
    @_use_os_devnull
    def seek(self, offset, whence=0):
        return 0

    @_use_os_devnull
    def tell(self):
        return 0

    # --- housekeeping ------------------------------------------------
    @_use_os_devnull
    def isatty(self):
        return False

    @_use_os_devnull
    def close(self):
        pass


# Module-level shared instance.
devnull = DevNull()
def convert_string_to_fd(name, world=None):
    """Turn an output specification into a writable text file object.

    None (or any rank other than 0) gives a file opened on os.devnull,
    '-' gives sys.stdout, and a str or path is opened for writing.
    Anything else is handed back unchanged on the assumption that it
    already is a file object.
    """
    if world is None:
        from ase.parallel import world

    # Only rank 0 produces real output.
    discard = name is None or world.rank != 0
    if discard:
        return open(os.devnull, 'w')

    if name == '-':
        return sys.stdout

    if not isinstance(name, (str, PurePath)):
        # we assume name is already a file-descriptor
        return name

    return open(str(name), 'w')  # str for py3.5 pathlib
# Flags for "create exclusively for writing".
CEW_FLAGS = os.O_CREAT | os.O_EXCL | os.O_WRONLY
# Only Windows has O_BINARY:
CEW_FLAGS |= getattr(os, 'O_BINARY', 0)


@contextmanager
def xwopen(filename, world=None):
    """Context manager around opencew() that closes the file on exit.

    Yields whatever opencew(filename, world) returns: a file object,
    or None when exclusive write access could not be obtained.  A
    yielded file object is closed when the block is left.
    """
    handle = opencew(filename, world)
    if handle is None:
        # Nothing was opened, hence nothing to clean up.
        yield None
        return

    try:
        yield handle
    finally:
        handle.close()


#@deprecated('use "with xwopen(...) as fd: ..." to prevent resource leak')
def opencew(filename, world=None):
    """Create and open filename exclusively for writing.

    Public entry point; the work is delegated to _opencew().  The
    caller is responsible for closing the returned file object.
    """
    return _opencew(filename, world=world)
def _opencew(filename, world=None):
    """Create and open filename exclusively for writing.

    Rank 0 tries to create the file with CEW_FLAGS; all other ranks open
    os.devnull as a dummy.  The error code is summed over ``world`` so
    that every rank takes the same decision.

    Returns a binary file object, or None (on all ranks) if the file
    already exists.  Raises OSError for any other error on rank 0.
    """
    if world is None:
        from ase.parallel import world

    closelater = []

    def opener(file, flags):
        return os.open(file, flags | CEW_FLAGS)

    try:
        error = 0
        if world.rank == 0:
            try:
                fd = open(filename, 'wb', opener=opener)
            except OSError as ex:
                error = ex.errno
            else:
                closelater.append(fd)
        else:
            fd = open(os.devnull, 'wb')
            closelater.append(fd)

        # Synchronize:
        error = world.sum(error)
        if error == errno.EEXIST:
            # Non-master ranks hold an open dummy descriptor at this
            # point; close it, otherwise it leaks on every failed attempt
            # (e.g. on each retry of Lock.acquire()).
            for opened in closelater:
                opened.close()
            return None
        if error:
            raise OSError(error, 'Error', filename)

        return fd
    except BaseException:
        for opened in closelater:
            opened.close()
        raise


class Lock:
    """File-based lock: holding the lock means having created the file.

    Parameters:

    name: str or path
        Name of the lock file.
    world: communicator
        Defaults to ase.parallel.world.
    timeout: float
        Maximum number of seconds acquire() keeps trying.
    """

    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        """Block until the lock file could be created.

        Retries with a sleep that doubles after each attempt and raises
        TimeoutError once ``timeout`` seconds have passed."""
        dt = 0.2
        t1 = time.time()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            time_left = self.timeout - (time.time() - t1)
            if time_left <= 0:
                raise TimeoutError
            time.sleep(min(dt, time_left))
            dt *= 2

    def release(self):
        """Close the lock file and remove it (removal on rank 0 only)."""
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        # Return the lock so that "with Lock(...) as lock:" is usable.
        return self

    def __exit__(self, type, value, tb):
        self.release()


class OpenLock:
    """Dummy lock with the same interface as Lock that never blocks."""

    def acquire(self):
        pass

    def release(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        pass


def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.

    Returns the hash as a string, or None if this is not rank 0, if no
    .git/HEAD (or the ref file it points to) is found, or if the content
    does not look like a hexadecimal hash.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, str):
        # Directory path
        dpath = arg
    else:
        # Assume arg is module
        dpath = os.path.dirname(arg.__file__)
    # dpath = os.path.abspath(dpath)
    # in case this is just symlinked into $PYTHONPATH
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    HEAD_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(HEAD_file):
        return None
    with open(HEAD_file, 'r') as f:
        line = f.readline().strip()
    if line.startswith('ref: '):
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
    else:
        # Assuming detached HEAD state
        ref_file = HEAD_file
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file, 'r') as f:
        line = f.readline().strip()
    # An empty line must not count as a hash (all() of nothing is True).
    if line and all(c in string.hexdigits for c in line):
        return line
    return None


def rotate(rotations, rotation=np.identity(3)):
    """Convert string of format '50x,-10y,120z' to a rotation matrix.

    Note that the order of rotation matters, i.e. '50x,40z' is different
    from '40z,50x'.

    The rotations are applied on top of ``rotation``, which is not
    modified; a new array is returned.
    """

    if rotations == '':
        return rotation.copy()

    for i, a in [('xyz'.index(s[-1]), radians(float(s[:-1])))
                 for s in rotations.split(',')]:
        s = sin(a)
        c = cos(a)
        if i == 0:
            rotation = np.dot(rotation, [(1, 0, 0),
                                         (0, c, s),
                                         (0, -s, c)])
        elif i == 1:
            rotation = np.dot(rotation, [(c, 0, -s),
                                         (0, 1, 0),
                                         (s, 0, c)])
        else:
            rotation = np.dot(rotation, [(c, s, 0),
                                         (-s, c, 0),
                                         (0, 0, 1)])
    return rotation


def givens(a, b):
    """Solve the equation system::

      [ c s]   [a]   [r]
      [    ] . [ ] = [ ]
      [-s c]   [b]   [0]

    Returns the tuple (c, s, r).
    """
    sgn = np.sign
    if b == 0:
        c = sgn(a)
        s = 0
        r = abs(a)
    elif abs(b) >= abs(a):
        # Divide by the larger of the two numbers.
        cot = a / b
        u = sgn(b) * (1 + cot**2)**0.5
        s = 1. / u
        c = s * cot
        r = b * u
    else:
        tan = b / a
        u = sgn(a) * (1 + tan**2)**0.5
        c = 1. / u
        s = c * tan
        r = a * u
    return c, s, r


def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles from rotation matrix.

    The angles are returned in degrees as the tuple (x, y, z).
    """
    a = np.dot(initial, rotation)
    cx, sx, rx = givens(a[2, 2], a[1, 2])
    cy, sy, ry = givens(rx, a[0, 2])
    cz, sz, rz = givens(cx * a[1, 1] - sx * a[2, 1],
                        cy * a[0, 1] - sy * (sx * a[1, 1] + cx * a[2, 1]))
    x = degrees(atan2(sx, cx))
    y = degrees(atan2(-sy, cy))
    z = degrees(atan2(sz, cz))
    return x, y, z


def pbc2pbc(pbc):
    """Broadcast pbc (a single flag or three flags) to a bool array of
    length 3."""
    newpbc = np.empty(3, bool)
    newpbc[:] = pbc
    return newpbc


def hsv2rgb(h, s, v):
    """http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in [0, 360[
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if h is outside the supported range.
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    # i: index of the 60-degree sector, f: fractional position inside it.
    i, f = divmod(h / 60., 1)
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))

    if i == 0:
        return v, t, p
    elif i == 1:
        return q, v, p
    elif i == 2:
        return p, v, t
    elif i == 3:
        return p, q, v
    elif i == 4:
        return t, p, v
    elif i == 5:
        return v, p, q
    else:
        raise RuntimeError('h must be in [0, 360]')


def hsv(array, s=.9, v=.9):
    """Map the values of a numpy array to RGB colors via the hue.

    The values are scaled linearly so that the minimum gets hue 0 and
    the maximum hue 359; ``s`` and ``v`` are the saturation and value
    used for all elements.  Returns an array of shape
    ``array.shape + (3,)``.
    """
    lowest = array.min()
    span = array.max() - lowest
    if span == 0:
        # All values are equal: avoid 0/0 and give everything hue 0.
        hues = np.zeros(array.shape)
    else:
        # Shift by *minus* the minimum so that hues fall in [0, 359].
        hues = (array - lowest) * 359. / span
    result = np.empty((hues.size, 3))
    for rgb, h in zip(result, hues.flat):
        rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, array.shape + (3,))


# This code does the same, but requires pylab
# def cmap(array, name='hsv'):
#     import pylab
#     a = (array + array.min()) / array.ptp()
#     rgba = getattr(pylab.cm, name)(a)
#     return rgba[:-1] # return rgb only (not alpha)


def longsum(x):
    """128-bit floating point sum.

    The sum is accumulated as np.longdouble (whose actual precision is
    platform dependent) and returned as a Python float."""
    return float(np.asarray(x, dtype=np.longdouble).sum())
@contextmanager
def workdir(path, mkdir=False):
    """Run the enclosed block with path as current working directory.

    With mkdir=True the directory (including missing parents) is
    created first.  The previous working directory is restored when the
    block is left, also if it is left through an exception.
    """
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    # Remember where we came from before moving.
    previous = os.getcwd()
    # py3.6 allows chdir(path) but we still need 3.5
    os.chdir(str(target))
    try:
        yield  # Yield the Path or dirname maybe?
    finally:
        os.chdir(previous)
class iofunction:
    """Decorator factory letting a function take a filename or a file.

    The decorated function receives an open file as first argument.  If
    the caller passed a str or path, the file is opened with the
    configured mode and closed again afterwards; any other object is
    passed through untouched and left open.

    (Won't work on functions that return a generator.)"""

    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                # Already a file-like object: the caller owns it.
                return func(file, *args, **kwargs)
            with open(str(file), self.mode) as fd:
                return func(fd, *args, **kwargs)
        return iofunc


def writer(func):
    """Decorate func so that a filename is opened in 'w' mode."""
    decorate = iofunction('w')
    return decorate(func)


def reader(func):
    """Decorate func so that a filename is opened in 'r' mode."""
    decorate = iofunction('r')
    return decorate(func)


# The next two functions are for hotplugging into a JSONable class
# using the jsonable decorator.  We are supposed to have this kind of stuff
# in ase.io.jsonio, but we'd rather import them from a 'basic' module
# like ase/utils than one which triggers a lot of extra (cyclic) imports.

def write_json(self, fd):
    """Write to JSON file."""
    from ase.io.jsonio import write_json as _write_json
    _write_json(fd, self)


@classmethod  # type: ignore
def read_json(cls, fd):
    """Read new instance from JSON file."""
    from ase.io.jsonio import read_json as _read_json
    instance = _read_json(fd)
    assert type(instance) is cls
    return instance


def jsonable(name):
    """Class decorator adding JSON-based read() and write().

    The decorated class gets ``name`` stored as its ase_objtype and must
    implement todict(); otherwise TypeError is raised.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict().  If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading."""
    def jsonableclass(cls):
        cls.ase_objtype = name
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')

        # We may want the write and read to be optional.
        # E.g. a calculator might want to be JSONable, but not
        # that .write() produces a JSON file.
        #
        # This is mostly for 'lightweight' object IO.
        cls.write = write_json
        cls.read = read_json
        return cls
    return jsonableclass


class ExperimentalFeatureWarning(Warning):
    """Warning category emitted by functions marked @experimental."""
    pass


def experimental(func):
    """Decorator for functions not ready for production use.

    Each call emits an ExperimentalFeatureWarning naming the function
    before the function itself is run."""
    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        message = 'This function may change or misbehave: {}()'.format(
            func.__qualname__)
        warnings.warn(message, ExperimentalFeatureWarning)
        return func(*args, **kwargs)
    return expfunc


def lazymethod(meth):
    """Decorator for lazy evaluation and caching of data.

    Example::

      class MyClass:

         @lazymethod
         def thing(self):
             return expensive_calculation()

    The method body is only executed first time thing() is called, and
    its return value is stored.  Subsequent calls return the cached
    value.  The values live in a dict stored on the instance as
    _lazy_cache, keyed by method name."""
    name = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        if not hasattr(self, '_lazy_cache'):
            self._lazy_cache = {}
        cache = self._lazy_cache

        if name in cache:
            return cache[name]
        cache[name] = meth(self)
        return cache[name]
    return getter


def atoms_to_spglib_cell(atoms):
    """Convert atoms into data suitable for calling spglib.

    Returns the tuple (cell, scaled positions, atomic numbers)."""
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return (cell, scaled_positions, numbers)


def warn_legacy(feature_name):
    """Emit a FutureWarning that feature_name is untested legacy code."""
    warnings.warn(
        f'The {feature_name} feature is untested and ASE developers do not '
        'know whether it works or how to use it.  Please rehabilitate it '
        '(by writing unittests) or it may be removed.',
        FutureWarning)


def lazyproperty(meth):
    """Decorator like lazymethod, but making item available as a property."""
    cached = lazymethod(meth)
    return property(cached)