//    MCL: MiMiC Communication Library
//    Copyright (C) 2015-2025  The MiMiC Authors (see CONTRIBUTORS file for details).
//
//    This file is part of MCL.
//
//    MCL is free software: you can redistribute it and/or modify
//    it under the terms of the GNU Lesser General Public License as
//    published by the Free Software Foundation, either version 3 of
//    the License, or (at your option) any later version.
//
//    MCL is distributed in the hope that it will be useful, but
//    WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

#include "mpi_mpmd_transport.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "mpi.h"

#include "data_types.h"
#include "error_codes.h"

using namespace mcl;
using ::testing::ContainsRegex;
using ::testing::TestEventListeners;
using ::testing::UnitTest;
using ::testing::ElementsAreArray;
using ::testing::Not;
using ::testing::ExitedWithCode;
using ::testing::internal::CaptureStderr;
using ::testing::internal::GetCapturedStderr;

/// Test fixture shared by all MpiMpmdTransport tests.
///
/// Before every test, all processes synchronize on MPI_COMM_WORLD. The shared
/// static `transport` is then rebuilt around `local_comm` and initialized.
/// NOTE(review): the tests later use `local_comm` as the program-local
/// communicator (MPI_Comm_rank / MPI_Bcast), which suggests that
/// Initialize() rewrites it through the pointer. Confirm against the transport.
class MpiMpmdTransportTest : public ::testing::Test {
  protected:
    static MPI_Comm          world_comm; // MPI_COMM_WORLD, assigned in SetUpTestSuite()
    static MpiMpmdTransport  transport;  // shared by all tests, rebuilt in SetUp()

    // Initialized to sentinels so that nothing is ever read uninitialized;
    // SetUp() overwrites all three before any test body runs.
    MPI_Comm local_comm = MPI_COMM_NULL; // communicator handed to the transport
    int local_rank = -1;                 // this process's rank in local_comm
    int program_id = -1;                 // value of transport.program_id()

    static void SetUpTestSuite() { world_comm = MPI_COMM_WORLD; }

    void SetUp() override {
        // Make all processes start each test together.
        MPI_Barrier(world_comm);
        local_comm = world_comm;
        transport = MpiMpmdTransport(&local_comm);
        ASSERT_EQ(transport.Initialize(&local_comm), 0);
        MPI_Comm_rank(local_comm, &local_rank);
        program_id = transport.program_id();
    }
};

// Out-of-class definitions of the fixture's static members.
// `transport` is constructed here with a pointer to `world_comm`. At this
// point `world_comm` has not yet been set to MPI_COMM_WORLD; that happens in
// SetUpTestSuite(). Each test's SetUp() replaces this instance anyway.
MPI_Comm         MpiMpmdTransportTest::world_comm;
MpiMpmdTransport MpiMpmdTransportTest::transport(&world_comm);
 

// Checks the roles assigned by Initialize(): the test expects three MPMD
// programs. Program 0 is the server and programs 1 and 2 are clients. Within
// each program only local rank 0 holds the server/client role.
TEST_F(MpiMpmdTransportTest, Initialize) {
    int num_programs = transport.num_programs();
    ASSERT_EQ(num_programs, 3);

    // Without this check, a process with an unexpected program id would skip
    // both role checks below and the test would pass vacuously.
    ASSERT_GE(program_id, 0);
    ASSERT_LT(program_id, num_programs);

    if (program_id == 0) {
        // server program
        ASSERT_FALSE(transport.is_client());
        ASSERT_EQ(transport.is_server(), local_rank == 0);
    } else {
        // client programs (ids 1 and 2)
        ASSERT_FALSE(transport.is_server());
        ASSERT_EQ(transport.is_client(), local_rank == 0);
    }
}

// Finalizing the freshly initialized transport must return 0.
TEST_F(MpiMpmdTransportTest, PureVirtual) {
    const auto finalize_status = transport.Finalize();
    ASSERT_EQ(finalize_status, 0);
}

// The server probes the size of each client's pending int message before
// receiving it. Clients 1 and 2 send 5 and 3 ints, respectively. Only local
// rank 0 of every program takes part in the exchange.
TEST_F(MpiMpmdTransportTest, ProbeSize) {
    // send_size[k] is the message length sent by client program k + 1.
    constexpr int send_size[2] = {5, 3};

    if (local_rank != 0) {
        return;
    }

    // server
    if (program_id == 0) {
        for (int i = 1; i <= 2; ++i) {
            int size = transport.ProbeSize(kTypeInt, i);
            ASSERT_EQ(size, send_size[i - 1]);
            // Drain the probed message. A vector owns the buffer, so it is
            // released even if Receive() throws (unlike the former new[]/delete[]).
            std::vector<int> recv_buffer(size);
            transport.Receive(recv_buffer.data(), size, kTypeInt, 0, i);
        }
    }
    // clients
    if (program_id == 1 || program_id == 2) {
        std::vector<int> send_data(send_size[program_id - 1]);
        transport.Send(send_data.data(), send_data.size(), kTypeInt, 0, 0);
    }
}

// Round-trip test between the server program (id 0) and the client programs
// (ids 1 and 2) for int, float and double payloads of 5 elements.
// Only local rank 0 of each program talks through the transport. It then
// broadcasts what it received over local_comm so that every rank can verify it.
// NOTE: the server's Send/Receive sequence and each client's Receive/Send
// sequence are mirrored. Reordering either side could block or pair mismatched
// data types.
// NOTE(review): if an ASSERT fails on local rank 0 before the MPI_Bcast calls,
// the other ranks may remain blocked in MPI_Bcast.
// Finally, the test checks that Send() rejects an unknown data type id.
TEST_F(MpiMpmdTransportTest, PingPong) {
    int     ref_data_int[5] = {1, 2, 3, 4, 5};
    float   ref_data_flt[5] = {1.1, 2.2, 3.3, 4.4, 5.5};
    double  ref_data_dbl[5] = {1.1, 2.2, 3.3, 4.4, 5.5};

    // server
    if (program_id == 0) {
        // One round trip per client program.
        for (int i = 1; i <= 2; ++i) {
            int     recv_data_int[5] = {0, 0, 0, 0, 0};
            float   recv_data_flt[5] = {0.0, 0.0, 0.0, 0.0, 0.0};
            double  recv_data_dbl[5] = {0.0, 0.0, 0.0, 0.0, 0.0};
            
            // Sanity check: the buffers do not already hold the reference data.
            ASSERT_THAT(recv_data_int, Not(ElementsAreArray(ref_data_int)));
            ASSERT_THAT(recv_data_flt, Not(ElementsAreArray(ref_data_flt)));
            ASSERT_THAT(recv_data_dbl, Not(ElementsAreArray(ref_data_dbl)));
            if (local_rank == 0) {
                // Send the reference data to client i and expect client i to
                // answer with its own copy of the same reference data.
                transport.Send(ref_data_int, 5, kTypeInt, 0, i);
                transport.Receive(recv_data_int, 5, kTypeInt, 0, i);
                ASSERT_THAT(recv_data_int, ElementsAreArray(ref_data_int));

                transport.Send(ref_data_flt, 5, kTypeFloat, 0, i);
                transport.Receive(recv_data_flt, 5, kTypeFloat, 0, i);
                ASSERT_THAT(recv_data_flt, ElementsAreArray(ref_data_flt));

                transport.Send(ref_data_dbl, 5, kTypeDouble, 0, i);
                transport.Receive(recv_data_dbl, 5, kTypeDouble, 0, i);
                ASSERT_THAT(recv_data_dbl, ElementsAreArray(ref_data_dbl));
            } else {
                // Other ranks have not received anything yet.
                ASSERT_THAT(recv_data_int, Not(ElementsAreArray(ref_data_int)));
                ASSERT_THAT(recv_data_flt, Not(ElementsAreArray(ref_data_flt)));
                ASSERT_THAT(recv_data_dbl, Not(ElementsAreArray(ref_data_dbl)));
            }
            // Share local rank 0's received data with the whole server program.
            MPI_Bcast(recv_data_int, 5, MPI_INT, 0, local_comm);
            MPI_Bcast(recv_data_flt, 5, MPI_FLOAT, 0, local_comm);
            MPI_Bcast(recv_data_dbl, 5, MPI_DOUBLE, 0, local_comm);
            ASSERT_THAT(recv_data_int, ElementsAreArray(ref_data_int));
            ASSERT_THAT(recv_data_flt, ElementsAreArray(ref_data_flt));
            ASSERT_THAT(recv_data_dbl, ElementsAreArray(ref_data_dbl));
        }
    }

    // clients
    if (program_id == 1 || program_id == 2) {
        int     recv_data_int[5] = {0, 0, 0, 0, 0};
        float   recv_data_flt[5] = {0.0, 0.0, 0.0, 0.0, 0.0};
        double  recv_data_dbl[5] = {0.0, 0.0, 0.0, 0.0, 0.0};
        
        // Sanity check: the buffers do not already hold the reference data.
        ASSERT_THAT(recv_data_int, Not(ElementsAreArray(ref_data_int)));
        ASSERT_THAT(recv_data_flt, Not(ElementsAreArray(ref_data_flt)));
        ASSERT_THAT(recv_data_dbl, Not(ElementsAreArray(ref_data_dbl)));
        if (local_rank == 0) {
            // Mirror of the server sequence: receive from the server (program 0),
            // verify, then reply with the local reference data.
            transport.Receive(recv_data_int, 5, kTypeInt, 0, 0);
            ASSERT_THAT(recv_data_int, ElementsAreArray(ref_data_int));
            transport.Send(ref_data_int, 5, kTypeInt, 0, 0);

            transport.Receive(recv_data_flt, 5, kTypeFloat, 0, 0);
            ASSERT_THAT(recv_data_flt, ElementsAreArray(ref_data_flt));
            transport.Send(ref_data_flt, 5, kTypeFloat, 0, 0);

            transport.Receive(recv_data_dbl, 5, kTypeDouble, 0, 0);
            ASSERT_THAT(recv_data_dbl, ElementsAreArray(ref_data_dbl));
            transport.Send(ref_data_dbl, 5, kTypeDouble, 0, 0);
        } else {
            // Other ranks have not received anything yet.
            ASSERT_THAT(recv_data_int, Not(ElementsAreArray(ref_data_int)));
            ASSERT_THAT(recv_data_flt, Not(ElementsAreArray(ref_data_flt)));
            ASSERT_THAT(recv_data_dbl, Not(ElementsAreArray(ref_data_dbl)));
        }
        // Share local rank 0's received data with the whole client program.
        MPI_Bcast(recv_data_int, 5, MPI_INT, 0, local_comm);
        MPI_Bcast(recv_data_flt, 5, MPI_FLOAT, 0, local_comm);
        MPI_Bcast(recv_data_dbl, 5, MPI_DOUBLE, 0, local_comm);
        ASSERT_THAT(recv_data_int, ElementsAreArray(ref_data_int));
        ASSERT_THAT(recv_data_flt, ElementsAreArray(ref_data_flt));
        ASSERT_THAT(recv_data_dbl, ElementsAreArray(ref_data_dbl));
    }
    
    // test invalid type in PickMpiType
    // Send() with an unrecognized type id must return kErrUnknownType and
    // print an MCL error banner to stderr.
    if (program_id == 0 && local_rank == 0){
        int kTypeFake = -868676;
        CaptureStderr();
        ASSERT_EQ(transport.Send(ref_data_int, 5, kTypeFake, 0, 1), kErrUnknownType);
        std::string error = GetCapturedStderr();
        ASSERT_THAT(error, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
        ASSERT_THAT(error, ContainsRegex("Unrecognized data type!"));
    }
}

int main(int argc, char** argv) {
    ::testing::InitGoogleTest(&argc, argv);
    MPI_Init(&argc, &argv);

    int world_rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
    TestEventListeners& listeners = UnitTest::GetInstance()->listeners();
    if (world_rank != 0) 
        delete listeners.Release(listeners.default_result_printer());

    auto result = RUN_ALL_TESTS();
    
    MPI_Finalize();
    return result;
}
