Overriding Collective Functions

From Mpich2

Revision as of 16:47, 22 January 2008 by Goodell


The communicator struct (MPID_Comm) has a field called coll_fns, a pointer to an MPID_Collops table (defined in src/include/mpiimpl.h) of collective functions. By default this field is NULL. If coll_fns is NULL, or if the table entry for a particular collective is NULL, the default implementation of that collective is used.

So, to override one or more collective functions for a communicator, create an instance of the MPID_Collops table and fill in the entries to point to your new functions. Then, for each communicator whose functions you want to override, set comm->coll_fns to point to that table. Because the default function is used for any NULL entry, you need only fill in the entries for the functions you want to override and set the others to NULL.
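The fallback rule described above can be sketched in a few lines of standalone C. The types mock_comm and mock_collops, the dispatch helpers coll_barrier() and coll_bcast(), and the simplified Bcast signature are hypothetical stand-ins, not MPICH2's actual definitions or dispatch code; the sketch only illustrates the NULL-table / NULL-entry logic and a single table shared with a partial override.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical, heavily simplified stand-ins for MPID_Comm and
 * MPID_Collops; the real structs have many more fields. */
typedef struct mock_comm mock_comm;

typedef struct {
    int ref_count;
    int (*Barrier)(mock_comm *);
    int (*Bcast)(void *, int, mock_comm *);  /* simplified signature */
} mock_collops;

struct mock_comm {
    mock_collops *coll_fns;  /* NULL means "use the defaults" */
    int last_impl;           /* records which implementation ran */
};

enum { IMPL_DEFAULT = 1, IMPL_CUSTOM = 2 };

static int default_barrier(mock_comm *c) { c->last_impl = IMPL_DEFAULT; return 0; }
static int default_bcast(void *buf, int n, mock_comm *c)
{
    (void)buf; (void)n;
    c->last_impl = IMPL_DEFAULT;
    return 0;
}
static int custom_barrier(mock_comm *c) { c->last_impl = IMPL_CUSTOM; return 0; }

/* Only Barrier is overridden; the NULL Bcast entry falls back. */
static mock_collops my_fns = { 0, custom_barrier, NULL };

/* Dispatch rule: use the table entry only if both the table and
 * the entry are non-NULL; otherwise call the default. */
int coll_barrier(mock_comm *c)
{
    assert(c != NULL);
    if (c->coll_fns != NULL && c->coll_fns->Barrier != NULL)
        return c->coll_fns->Barrier(c);
    return default_barrier(c);
}

int coll_bcast(void *buf, int n, mock_comm *c)
{
    assert(c != NULL);
    if (c->coll_fns != NULL && c->coll_fns->Bcast != NULL)
        return c->coll_fns->Bcast(buf, n, c);
    return default_bcast(buf, n, c);
}
```

A communicator whose coll_fns is NULL always gets the defaults, while any number of communicators pointing at the same my_fns table get the custom barrier but still fall back to the default broadcast.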

To set coll_fns when a communicator is created, define MPID_Dev_comm_create_hook() to set it. You can also define MPID_Dev_comm_destroy_hook() to do any cleanup when the communicator is destroyed. You'll also need to define HAVE_DEV_COMM_HOOK.
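This page does not show how the generic communicator code invokes the hooks. One plausible pattern, assumed here, is conditional compilation on HAVE_DEV_COMM_HOOK. The standalone sketch below uses hypothetical names (DEV_comm_create_hook, generic_comm_init, generic_comm_free, mock_comm) rather than the real MPICH2 ones, and its destroy hook clears the pointer purely for illustration.

```c
#include <assert.h>
#include <stddef.h>

/* Simplified, hypothetical stand-in types. */
typedef struct { int ref_count; int (*Barrier)(void *); } mock_collops;
typedef struct { mock_collops *coll_fns; } mock_comm;

static int fast_barrier(void *comm) { (void)comm; return 0; }
static mock_collops dev_fns = { 0, fast_barrier };

/* Device side: opt in to the hooks and supply them. */
#define HAVE_DEV_COMM_HOOK
#define DEV_comm_create_hook(c)  ((c)->coll_fns = &dev_fns)
#define DEV_comm_destroy_hook(c) ((c)->coll_fns = NULL)

/* Generic side (assumed pattern): run the hooks only when the
 * device has declared that it provides them. */
void generic_comm_init(mock_comm *c)
{
    assert(c != NULL);
    c->coll_fns = NULL;            /* default: no overrides */
#ifdef HAVE_DEV_COMM_HOOK
    DEV_comm_create_hook(c);       /* device installs its table */
#endif
}

void generic_comm_free(mock_comm *c)
{
    assert(c != NULL);
#ifdef HAVE_DEV_COMM_HOOK
    DEV_comm_destroy_hook(c);      /* device-specific cleanup */
#endif
}
```

With this arrangement a device that never defines HAVE_DEV_COMM_HOOK compiles the generic code unchanged and every communicator keeps the default collectives.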

Example

Here's an example for overriding barrier for all communicators.

Define the hooks in mpidi_ch3_pre.h:


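#define HAVE_DEV_COMM_HOOK

/* The hooks expand inside the function that creates or destroys the
   communicator.  MPIU_ERR_POP jumps to that function's fn_fail label,
   so the result is stored in its mpi_errno rather than in a new local
   variable, which would shadow the caller's and hide the error. */
#define MPID_Dev_comm_create_hook(comm) do {         \
        mpi_errno = MPIDI_CH3I_comm_create(comm);    \
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);      \
    } while(0)

#define MPID_Dev_comm_destroy_hook(comm) do {        \
        mpi_errno = MPIDI_CH3I_comm_destroy(comm);   \
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);      \
    } while(0)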
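Because MPIU_ERR_POP jumps to a fn_fail label in the enclosing function (the usual MPICH2 error-handling convention), it matters which mpi_errno a hook macro writes to. A hook that declares its own `int mpi_errno` inside the do/while block shadows the enclosing function's variable, so the error code is lost when control reaches fn_fail. The standalone sketch below demonstrates this; MY_ERR_POP, failing_dev_create() and the error values are hypothetical stand-ins, not MPICH2 definitions.

```c
#include <assert.h>

#define MY_SUCCESS   0
#define MY_ERR_OTHER 15

/* Hypothetical stand-in for MPIU_ERR_POP: jump to the enclosing
 * function's fn_fail label (err is left unchanged in this sketch). */
#define MY_ERR_POP(err) do { (void)(err); goto fn_fail; } while (0)

static int failing_dev_create(void) { return MY_ERR_OTHER; }

/* Shadowing version: the error lands in a block-local mpi_errno. */
#define HOOK_SHADOWED() do {                    \
        int mpi_errno;                          \
        mpi_errno = failing_dev_create();       \
        if (mpi_errno) MY_ERR_POP(mpi_errno);   \
    } while (0)

/* Non-shadowing version: assigns the caller's mpi_errno. */
#define HOOK_FIXED() do {                       \
        mpi_errno = failing_dev_create();       \
        if (mpi_errno) MY_ERR_POP(mpi_errno);   \
    } while (0)

int comm_create_shadowed(void)
{
    int mpi_errno = MY_SUCCESS;
    HOOK_SHADOWED();
 fn_exit:
    return mpi_errno;
 fn_fail:
    goto fn_exit;
}

int comm_create_fixed(void)
{
    int mpi_errno = MY_SUCCESS;
    HOOK_FIXED();
 fn_exit:
    return mpi_errno;
 fn_fail:
    goto fn_exit;
}
```

Both functions take the fn_fail path, but only the non-shadowing one returns the failure code to its caller.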

Implement the functions in a .c file:


#include "mpiimpl.h"   /* MPID_Comm, MPID_Collops, MPI_SUCCESS */

static int barrier(MPID_Comm *comm_ptr);

static MPID_Collops collective_functions = {
    0,    /* ref_count */
    barrier, /* Barrier */
    NULL, /* Bcast */
    NULL, /* Gather */
    NULL, /* Gatherv */
    NULL, /* Scatter */
    NULL, /* Scatterv */
    NULL, /* Allgather */
    NULL, /* Allgatherv */
    NULL, /* Alltoall */
    NULL, /* Alltoallv */
    NULL, /* Alltoallw */
    NULL, /* Reduce */
    NULL, /* Allreduce */
    NULL, /* Reduce_scatter */
    NULL, /* Scan */
    NULL  /* Exscan */
};

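int MPIDI_CH3I_comm_create(MPID_Comm *comm)
{
    /* Every new communicator shares the same static override table. */
    comm->coll_fns = &collective_functions;

    return MPI_SUCCESS;
}

int MPIDI_CH3I_comm_destroy(MPID_Comm *comm)
{
    /* Nothing to release: the table is statically allocated. */
    return MPI_SUCCESS;
}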

static int barrier(MPID_Comm *comm_ptr)
{
    int mpi_errno = MPI_SUCCESS;

    /* New barrier implementation goes here; set mpi_errno on failure. */

    return mpi_errno;
}
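The barrier body above is left empty. One classic algorithm that could fill it is the dissemination barrier: in the round with distance d = 1, 2, 4, ..., each rank r signals rank (r + d) mod p and waits for a signal from rank (r - d + p) mod p, so after ceil(log2 p) rounds every rank has transitively heard from every other. The sketch below only simulates that schedule in a single process with bitmasks (limited to 64 ranks) to check the property; a real version inside MPICH2 would replace the bitmask updates with the library's internal point-to-point calls, which are not shown here. The function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* known[r] bit j set => rank r knows rank j has entered the barrier.
 * Simulates the dissemination schedule for p ranks (1 <= p <= 64)
 * and returns the number of rounds used. */
int simulate_dissemination(int p, uint64_t known[])
{
    uint64_t next[64];
    int rounds = 0;

    assert(p >= 1 && p <= 64);
    for (int r = 0; r < p; r++)
        known[r] = (uint64_t)1 << r;          /* each rank knows itself */

    for (int dist = 1; dist < p; dist *= 2) {
        /* All exchanges in a round happen "simultaneously", so
         * compute every update from the previous round's state. */
        for (int r = 0; r < p; r++) {
            int from = (r - dist + p) % p;    /* partner we receive from */
            next[r] = known[r] | known[from];
        }
        for (int r = 0; r < p; r++)
            known[r] = next[r];
        rounds++;
    }
    return rounds;
}

/* Returns 1 if every rank has heard from all p ranks, else 0. */
int all_synchronized(int p, const uint64_t known[])
{
    uint64_t all = (p == 64) ? UINT64_MAX : (((uint64_t)1 << p) - 1);

    for (int r = 0; r < p; r++)
        if (known[r] != all)
            return 0;
    return 1;
}
```

Unlike a gather-to-root-then-broadcast barrier, this schedule needs no designated root and completes in ceil(log2 p) rounds for any p, not only powers of two.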