# Copyright (c) 2023, The University of Texas at Austin
# & Georgia Institute of Technology
#
# All Rights reserved.
# See file COPYRIGHT for details.
#
# This file is part of the SOUPy package. For more information see
# https://github.com/hippylib/soupy/
#
# SOUPy is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License (as published by the Free
# Software Foundation) version 3.0 dated June 2007.
import dolfin as dl
import numpy as np
from mpi4py import MPI
class NullCollective:
    """
    Stand-in collective for the single-process case: every "reduction" or
    "broadcast" hands its input straight back, so no communication happens.
    """

    def bcast(self, v, root=0):
        """
        Return :code:`v` untouched. :code:`root` is accepted but not used.
        """
        return v

    def size(self):
        """
        Number of processes in this collective; always :code:`1`.
        """
        return 1

    def rank(self):
        """
        Rank of the calling process; always :code:`0`.
        """
        return 0

    def allReduce(self, v, op):
        """
        Return :code:`v` unchanged once :code:`op` has been validated.

        :param v: value to "reduce"; handed back as is
        :param op: :code:`"sum"` or :code:`"avg"` (compared case-insensitively)
        :raises NotImplementedError: if :code:`op` is any other operation
        """
        if op.lower() in ["sum", "avg"]:
            return v
        raise NotImplementedError(
            "Unknown operation *{0}* in NullCollective.allReduce".format(op)
        )
class MultipleSamePartitioningPDEsCollective:
    """
    Parallel reduction utilities when several serial systems of PDEs (one per process) are solved concurrently.
    """

    def __init__(self, comm, is_serial_check=False):
        """
        Constructor:

        :param comm: MPI communicator
        :type comm: :py:class:`mpi4py.MPI.Comm`
        :param is_serial_check: If :code:`True`, vector-like inputs (objects exposing
            :code:`mpi_comm` and :code:`get_local`) are asserted to live on a communicator
            of size 1 before being reduced or broadcast.
        :type is_serial_check: bool
        """
        self.comm = comm
        self.is_serial_check = is_serial_check

    def size(self):
        """
        Return the number of processes in :code:`self.comm`.
        """
        return self.comm.Get_size()

    def rank(self):
        """
        Return the rank of the calling process in :code:`self.comm`.
        """
        return self.comm.Get_rank()

    def _unsupported_type_error(self, method_name, v):
        """
        Build the :code:`NotImplementedError` raised when :code:`v` has a type that
        :code:`method_name` (:code:`"allReduce"` or :code:`"bcast"`) cannot handle.
        The class name in the message follows :code:`self.is_serial_check`.
        """
        if self.is_serial_check:
            class_name = "MultipleSerialPDEsCollective"
        else:
            class_name = "MultipleSamePartitioningPDEsCollective"
        msg = "{0}.{1} not implement for v of type {2}".format(class_name, method_name, type(v))
        return NotImplementedError(msg)

    def _allReduce_array(self, v, op):
        """
        Sum (or average) the numpy array :code:`v` across :code:`self.comm`, in place.

        :param v: numpy array; overwritten with the reduced values
        :param op: :code:`"sum"` or :code:`"avg"` (already lower-cased by the caller)
        :return: :code:`v`
        :raises NotImplementedError: if :code:`op` is neither :code:`"sum"` nor :code:`"avg"`
        """
        # Validate before communicating so an unknown operation fails without
        # first paying for a collective call.
        if op not in ("sum", "avg"):
            err_msg = "Unknown operation *{0}* in MultipleSerialPDEsCollective.allReduce".format(op)
            raise NotImplementedError(err_msg)

        receive = np.zeros_like(v)
        # Both operations start from the global sum; "avg" rescales it afterwards.
        self.comm.Allreduce(v, receive, op=MPI.SUM)
        if op == "sum":
            v[:] = receive
        else:
            v[:] = (1./float(self.size()))*receive
        return v

    def allReduce(self, v, op):
        """
        Case handled:

        - :code:`v` is a scalar (:code:`float`, :code:`int`);
        - :code:`v` is a numpy array (NOTE: :code:`v` will be overwritten)
        - :code:`v` is a :code:`dolfin.Vector` (NOTE: :code:`v` will be overwritten)
        - :code:`v` exposes :code:`nvec()` and indexing (NOTE: each :code:`v[i]` is reduced in turn)

        Operation: :code:`op = "Sum"` or `"Avg"` (case insensitive).

        NOTE: integer scalars are reduced through an :code:`int32` buffer and the result is
        written back into that buffer, so a non-integer average is not returned as a float.

        :return: the reduced value (a numpy scalar for scalar input, otherwise :code:`v` itself)
        :raises NotImplementedError: if :code:`op` or the type of :code:`v` is not supported
        """
        op = op.lower()

        # Exact type checks are kept for scalars so that e.g. bool (a subclass
        # of int) is still rejected rather than silently reduced.
        if type(v) in (float, np.float64):
            v_array = np.array([v], dtype=np.float64)
            self._allReduce_array(v_array, op)
            return v_array[0]
        elif type(v) in (int, np.int32):
            v_array = np.array([v], dtype=np.int32)
            self._allReduce_array(v_array, op)
            return v_array[0]
        elif isinstance(v, np.ndarray):
            return self._allReduce_array(v, op)
        elif hasattr(v, "mpi_comm") and hasattr(v, "get_local"):
            # v is most likely a dl.Vector
            if self.is_serial_check:
                assert v.mpi_comm().Get_size() == 1
            v_array = v.get_local()
            self._allReduce_array(v_array, op)
            v.set_local(v_array)
            v.apply("")
            return v
        elif hasattr(v, 'nvec'):
            # presumably a multi-vector container; reduce entry by entry
            for i in range(v.nvec()):
                self.allReduce(v[i], op)
            return v
        else:
            raise self._unsupported_type_error("allReduce", v)

    def bcast(self, v, root=0):
        """
        Case handled:

        - :code:`v` is a scalar (:code:`float`, :code:`int`);
        - :code:`v` is a numpy array (NOTE: :code:`v` will be overwritten)
        - :code:`v` is a :code:`dolfin.Vector` (NOTE: :code:`v` will be overwritten)
        - :code:`v` exposes :code:`nvec()` and indexing (NOTE: each :code:`v[i]` is broadcast in turn)
        - :code:`root` refers to the process rank within the communicator for which the data to be
          broadcasted lives.

        :return: the broadcast value (a numpy scalar for scalar input, otherwise :code:`v` itself)
        :raises NotImplementedError: if the type of :code:`v` is not supported
        """
        if type(v) in (float, np.float64, int, np.int32):
            v_array = np.array([v])
            self.comm.Bcast(v_array, root=root)
            return v_array[0]

        if isinstance(v, np.ndarray):
            self.comm.Bcast(v, root=root)
            return v
        elif hasattr(v, "mpi_comm") and hasattr(v, "get_local"):
            # v is most likely a dl.Vector
            if self.is_serial_check:
                assert v.mpi_comm().Get_size() == 1
            v_local = v.get_local()
            self.comm.Bcast(v_local, root=root)
            v.set_local(v_local)
            v.apply("")
            return v
        elif hasattr(v, 'nvec'):
            # presumably a multi-vector container; broadcast entry by entry
            for i in range(v.nvec()):
                self.bcast(v[i], root=root)
            return v
        else:
            raise self._unsupported_type_error("bcast", v)
def MultipleSerialPDEsCollective(comm):
    """
    Build a :code:`MultipleSamePartitioningPDEsCollective` on :code:`comm` with
    :code:`is_serial_check` switched on.

    :param comm: communicator forwarded unchanged to the constructor
    :return: the newly created collective
    """
    collective = MultipleSamePartitioningPDEsCollective(comm, is_serial_check=True)
    return collective