//    MCL: MiMiC Communication Library
//    Copyright (C) 2015-2025  The MiMiC Authors (see CONTRIBUTORS file for details).
//
//    This file is part of MCL.
//
//    MCL is free software: you can redistribute it and/or modify
//    it under the terms of the GNU Lesser General Public License as
//    published by the Free Software Foundation, either version 3 of
//    the License, or (at your option) any later version.
//
//    MCL is distributed in the hope that it will be useful, but
//    WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

#include "mpi_mpmd_transport.h"

#include <vector>

#include "mpi.h"

#include "data_types.h"
#include "error_codes.h"
#include "error_print.h"

// Evaluates the MPI call FUNCTION exactly once. If it does not yield MPI_SUCCESS, the failure
// is reported through PrintErrorMessage (together with the call site's file, line and the MPI
// return code) and the *enclosing function* returns ERROR_CODE.
//
// The do { ... } while (0) wrapper turns the expansion into a single statement that needs a
// trailing semicolon, so the macro also works as the body of an unbraced if/else. The
// deliberately unusual local name keeps the macro from shadowing a caller variable that is
// referenced inside FUNCTION.
#define CHECK_MPI(FUNCTION, ERROR_CODE) do { \
            const int check_mpi_result_ = (FUNCTION); \
            if (check_mpi_result_ != MPI_SUCCESS) { \
                PrintErrorMessage("There has been a problem in the MPI MPMD communication protocol!", \
                                  __FILE__, __LINE__, check_mpi_result_); \
                return (ERROR_CODE); \
            } \
        } while (0)

namespace mcl {

int MpiMpmdTransport::Initialize(void *args) {
    // communication tags
    static constexpr int kClientIdTag = 303;
    static constexpr int kClientListTag = 304;

    // split the world communicator
    int *color, flag;
    CHECK_MPI(MPI_Comm_get_attr(world_comm_, MPI_APPNUM, &color, &flag), kErrMpiInitFail);
    if (flag) {
        CHECK_MPI(MPI_Comm_split(world_comm_, *color, 0, &local_comm_), kErrMpiInitFail);
        CHECK_MPI(MPI_Comm_dup(local_comm_, static_cast<MPI_Comm *>(args)), kErrMpiInitFail);
        program_id_ = *color;
    }

    int world_rank, local_rank;
    MPI_Comm_rank(world_comm_, &world_rank);
    MPI_Comm_rank(local_comm_, &local_rank);
  
    CHECK_MPI(MPI_Allreduce(&program_id_, &num_programs_, 1, MPI_INT, MPI_MAX, world_comm_), kErrMpiInitFail);
    num_programs_++;
    
    std::vector<int> intercomm_ranks(num_programs_);

    if (world_rank == 0) {
        // identify clients
        intercomm_ranks[0] = world_rank;
        for (int i = 1; i < num_programs_; ++i) {
            MPI_Status status;
            int client_id, remote_client;
            
            CHECK_MPI(MPI_Probe(MPI_ANY_SOURCE, kClientIdTag, world_comm_, &status), kErrMpiInitFail);
            remote_client = status.MPI_SOURCE;
            CHECK_MPI(MPI_Recv(&client_id, 1, MPI_INT, remote_client, kClientIdTag, world_comm_, &status),
                      kErrMpiInitFail);
            intercomm_ranks[client_id] = remote_client;
        }

        // send client list
        for (int i = 1;  i < num_programs_; ++i)
            MPI_Send(intercomm_ranks.data(), num_programs_, MPI_INT, intercomm_ranks[i], kClientListTag, world_comm_);
    } else if (local_rank == 0) {
        MPI_Status status;
        CHECK_MPI(MPI_Send(&program_id_, 1, MPI_INT, 0, kClientIdTag, world_comm_), kErrMpiInitFail);
        CHECK_MPI(MPI_Recv(intercomm_ranks.data(), num_programs_, MPI_INT, 0, kClientListTag, world_comm_, &status),
                  kErrMpiInitFail);
    }
    
    // create intercommunicator
    if (local_rank == 0) {
        MPI_Group world;
        MPI_Comm_group(world_comm_, &world);
        MPI_Group intercomm_group;
        CHECK_MPI(MPI_Group_incl(world, num_programs_, intercomm_ranks.data(), &intercomm_group), kErrMpiInitFail);
        CHECK_MPI(MPI_Comm_create_group(world_comm_, intercomm_group, 1, &intercomm_), kErrMpiInitFail);
    }

    is_server_ = (world_rank == 0);
    is_client_ = (local_rank == 0) && (world_rank != 0);
    
    return kSuccess;

}
    
// Forwards error_code to MPI_Abort on world_comm_ and hands back MPI_Abort's return value.
int MpiMpmdTransport::Abort(int error_code) {
    const int abort_status = MPI_Abort(world_comm_, error_code);
    return abort_status;
}


// Maps an MCL data type identifier onto the corresponding MPI datatype handle.
// An identifier outside the handled set is reported through PrintErrorMessage and mapped to
// MPI_BYTE, which the transport methods in this file treat as the "unknown type" marker.
MPI_Datatype PickMpiType(int data_type) {
    switch (data_type) {
      case kTypeChar:
        return MPI_CHAR;
      case kTypeInt:
        return MPI_INT;
      case kTypeLongInt:
        return MPI_LONG;
      case kTypeFloat:
        return MPI_FLOAT;
      case kTypeDouble:
        return MPI_DOUBLE;
      default:
        break;
    }
    // No case matched: report it and fall back to the marker value.
    PrintErrorMessage("Unrecognized data type!", __FILE__, __LINE__);
    return MPI_BYTE;
}

int MpiMpmdTransport::Send(void *data, int length, int data_type, int tag, int destination) {
    MPI_Datatype mpi_type = PickMpiType(data_type);
    if (mpi_type == MPI_BYTE) return kErrUnknownType;
    CHECK_MPI(MPI_Send(data, length, PickMpiType(data_type), destination, tag, intercomm_), kErrMpiCommFail);
    return kSuccess;
}

int MpiMpmdTransport::Receive(void *data, int length, int data_type, int tag, int source) {
    MPI_Datatype mpi_type = PickMpiType(data_type);
    if (mpi_type == MPI_BYTE) return kErrUnknownType;
    MPI_Status status;
    CHECK_MPI(MPI_Recv(data, length, mpi_type, source, tag, intercomm_, &status), kErrMpiCommFail);
    return kSuccess;
}

int MpiMpmdTransport::ProbeSize(int data_type, int source) {
    MPI_Datatype mpi_type = PickMpiType(data_type);
    if (mpi_type == MPI_BYTE) return kErrUnknownType;
    int size;
    MPI_Status status;
    CHECK_MPI(MPI_Probe(source, MPI_ANY_TAG, intercomm_, &status), kErrMpiCommFail);
    CHECK_MPI(MPI_Get_count(&status, mpi_type, &size), kErrMpiCommFail);
    return size;
}

} // namespace mcl
