//    MCL: MiMiC Communication Library
//    Copyright (C) 2015-2025  The MiMiC Authors (see CONTRIBUTORS file for details).
//
//    This file is part of MCL.
//
//    MCL is free software: you can redistribute it and/or modify
//    it under the terms of the GNU Lesser General Public License as
//    published by the Free Software Foundation, either version 3 of
//    the License, or (at your option) any later version.
//
//    MCL is distributed in the hope that it will be useful, but
//    WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

#include "mpi_port_transport.h"

#include <stdlib.h>

#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "mpi.h"

#include "data_types.h"
#include "env_controls.h"
#include "error_codes.h"
#include "error_print.h"
#include "test_macros.h"

using namespace mcl;
using ::testing::ContainsRegex;
using ::testing::TestEventListeners;
using ::testing::UnitTest;
using ::testing::ElementsAreArray;
using ::testing::Not;
using ::testing::internal::CaptureStderr;
using ::testing::internal::GetCapturedStderr;

class MpiPortTransportTest : public ::testing::Test {
  protected:
    static MPI_Comm         world_comm;       // MPI_COMM_WORLD
    static MPI_Comm         mini_world_comm;  // emulates MPI_COMM_WORLD in a run using MPI-type transport
    static int              world_rank;       // rank in world_comm
    static int              rank;             // rank in mini_world_comm
    static MpiPortTransport transport;        // transport under test, shared by the whole suite

    // Program id reported by the transport. Set in SetUp() after a successful
    // Initialize(); stays -1 in derived fixtures that skip Initialize().
    int program_id = -1;

    // Requires exactly 7 world ranks and splits them into three emulated
    // programs: world ranks 0-2 -> program 0, 3-4 -> program 1, 5-6 -> program 2.
    // The program number is exported through env_program_number so the
    // transport can pick it up, and mini_world_comm is created per program.
    static void SetUpTestSuite() {
        world_comm = MPI_COMM_WORLD;

        int world_size = 0;
        MPI_Comm_size(world_comm, &world_size);
        if (world_size != 7)
            throw std::runtime_error("Tests must be run with 7 processes!");

        MPI_Comm_rank(world_comm, &world_rank);

        // Exhaustive mapping: every rank in [0, 7) gets a defined program id.
        int program_id_tmp = 0;
        if (world_rank < 3)
            program_id_tmp = 0;
        else if (world_rank < 5)
            program_id_tmp = 1;
        else
            program_id_tmp = 2;

        SET_ENV_CONTROL(env_program_number, std::to_string(program_id_tmp).c_str());
        SET_ENV_CONTROL(env_port_directory, ".");

        MPI_Comm_split(world_comm, program_id_tmp, 0, &mini_world_comm);
        MPI_Comm_rank(mini_world_comm, &rank);
    }

    // World rank 0 writes the dummy port-name file ".MCL_portname_1". After a
    // world-wide barrier, every rank creates a fresh transport and
    // initializes it.
    void SetUp() override {
        if (world_rank == 0) {
            std::ofstream port_file(".MCL_portname_1");
            port_file << "12345";
        }  // port_file is closed here, before the barrier
        MPI_Barrier(world_comm);
        transport = MpiPortTransport(&mini_world_comm);
        ASSERT_EQ(transport.Initialize(&mini_world_comm), 0);
        // Each TEST_F gets a new fixture object, so store the id here; tests
        // that do not query it themselves would otherwise read garbage.
        program_id = transport.program_id();
    }

    void TearDown() override {
        transport.Finalize();
    }

};

// Out-of-class definitions of the fixture's suite-wide state, listed in
// declaration order. The communicators and ranks are zero-initialized here and
// assigned in SetUpTestSuite().
MPI_Comm          MpiPortTransportTest::world_comm;
MPI_Comm          MpiPortTransportTest::mini_world_comm;
int               MpiPortTransportTest::world_rank;
int               MpiPortTransportTest::rank;
// Only the address of mini_world_comm is passed here; the communicator itself
// is filled in later by SetUpTestSuite().
MpiPortTransport  MpiPortTransportTest::transport(&mini_world_comm);


// Checks the topology reported by the transport. There are three programs.
// In program 0 no rank is a client and only rank 0 is the server. In programs
// 1 and 2 no rank is a server and only rank 0 is a client.
TEST_F(MpiPortTransportTest, Initialize) {
    program_id = transport.program_id();

    ASSERT_EQ(transport.num_programs(), 3);

    // server
    if (program_id == 0) {
        ASSERT_FALSE(transport.is_client());
        // ASSERT_EQ (instead of ASSERT_TRUE(a == b)) reports both sides on failure.
        ASSERT_EQ(transport.is_server(), rank == 0);
    }
    // clients
    if (program_id == 1 || program_id == 2) {
        ASSERT_FALSE(transport.is_server());
        ASSERT_EQ(transport.is_client(), rank == 0);
    }
}

// Rank 0 of each client program (1 and 2) sends an int message of a
// program-specific length to program 0. Rank 0 of program 0 checks the length
// returned by ProbeSize() for each client and then receives the message.
TEST_F(MpiPortTransportTest, ProbeSize) {
    // Query the id explicitly: the fixture member is not otherwise guaranteed
    // to be set for this test instance.
    program_id = transport.program_id();

    constexpr int send_size[2] = {5, 3};  // message length sent by program 1 / 2

    if (rank == 0) {
        // server
        if (program_id == 0) {
            for (int i = 1; i <= 2; ++i) {
                int size = transport.ProbeSize(kTypeInt, i);
                ASSERT_EQ(size, send_size[i-1]);
                // size is known to be positive here (checked above).
                std::vector<int> recv_buffer(size);
                transport.Receive(recv_buffer.data(), size, kTypeInt, 0, i);
            }
        }
        // clients
        if (program_id == 1 || program_id == 2) {
            std::vector<int> send_data(send_size[program_id-1]);
            transport.Send(send_data.data(), send_data.size(), kTypeInt, 0, 0);
        }
    }
}

// Round trip between the server program (0) and each client program (1, 2).
// Rank 0 of program 0 sends int, float and double arrays to each client and
// expects them echoed back. Rank 0 of each client receives the arrays and
// echoes them. In every program the received data is then broadcast over
// mini_world_comm so all of its ranks can check it. All ranks finish at a
// barrier on world_comm.
TEST_F(MpiPortTransportTest, PingPong) {
    // Query the id explicitly: the fixture member is not otherwise guaranteed
    // to be set for this test instance.
    program_id = transport.program_id();

    int     ref_data_int[5] = {1, 2, 3, 4, 5};
    float   ref_data_flt[5] = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f};
    double  ref_data_dbl[5] = {1.1, 2.2, 3.3, 4.4, 5.5};

    // server
    if (program_id == 0) {
        for (int i = 1; i <= 2; ++i) {
            int    recv_data_int[5] = {0, 0, 0, 0, 0};
            float  recv_data_flt[5] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
            double recv_data_dbl[5] = {0.0, 0.0, 0.0, 0.0, 0.0};

            // Sanity check: the buffers must start out different from the reference.
            ASSERT_THAT(recv_data_int, Not(ElementsAreArray(ref_data_int)));
            ASSERT_THAT(recv_data_flt, Not(ElementsAreArray(ref_data_flt)));
            ASSERT_THAT(recv_data_dbl, Not(ElementsAreArray(ref_data_dbl)));

            if (rank == 0) {
                transport.Send(ref_data_int, 5, kTypeInt, 0, i);
                transport.Receive(recv_data_int, 5, kTypeInt, 0, i);
                ASSERT_THAT(recv_data_int, ElementsAreArray(ref_data_int));

                transport.Send(ref_data_flt, 5, kTypeFloat, 0, i);
                transport.Receive(recv_data_flt, 5, kTypeFloat, 0, i);
                ASSERT_THAT(recv_data_flt, ElementsAreArray(ref_data_flt));

                transport.Send(ref_data_dbl, 5, kTypeDouble, 0, i);
                transport.Receive(recv_data_dbl, 5, kTypeDouble, 0, i);
                ASSERT_THAT(recv_data_dbl, ElementsAreArray(ref_data_dbl));
            } else {
                // Non-root ranks take no part in the exchange.
                ASSERT_THAT(recv_data_int, Not(ElementsAreArray(ref_data_int)));
                ASSERT_THAT(recv_data_flt, Not(ElementsAreArray(ref_data_flt)));
                ASSERT_THAT(recv_data_dbl, Not(ElementsAreArray(ref_data_dbl)));
            }

            // Share rank 0's received data with the rest of the program.
            MPI_Bcast(recv_data_int, 5, MPI_INT, 0, mini_world_comm);
            MPI_Bcast(recv_data_flt, 5, MPI_FLOAT, 0, mini_world_comm);
            MPI_Bcast(recv_data_dbl, 5, MPI_DOUBLE, 0, mini_world_comm);
            ASSERT_THAT(recv_data_int, ElementsAreArray(ref_data_int));
            ASSERT_THAT(recv_data_flt, ElementsAreArray(ref_data_flt));
            ASSERT_THAT(recv_data_dbl, ElementsAreArray(ref_data_dbl));
        }

        MPI_Barrier(world_comm);
    }

    // clients
    if (program_id == 1 || program_id == 2) {
        int     recv_data_int[5] = {0, 0, 0, 0, 0};
        float   recv_data_flt[5] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
        double  recv_data_dbl[5] = {0.0, 0.0, 0.0, 0.0, 0.0};

        ASSERT_THAT(recv_data_int, Not(ElementsAreArray(ref_data_int)));
        ASSERT_THAT(recv_data_flt, Not(ElementsAreArray(ref_data_flt)));
        ASSERT_THAT(recv_data_dbl, Not(ElementsAreArray(ref_data_dbl)));
        if (rank == 0) {
            // Receive each array from the server and echo it back.
            transport.Receive(recv_data_int, 5, kTypeInt, 0, 0);
            ASSERT_THAT(recv_data_int, ElementsAreArray(ref_data_int));
            transport.Send(ref_data_int, 5, kTypeInt, 0, 0);

            transport.Receive(recv_data_flt, 5, kTypeFloat, 0, 0);
            ASSERT_THAT(recv_data_flt, ElementsAreArray(ref_data_flt));
            transport.Send(ref_data_flt, 5, kTypeFloat, 0, 0);

            transport.Receive(recv_data_dbl, 5, kTypeDouble, 0, 0);
            ASSERT_THAT(recv_data_dbl, ElementsAreArray(ref_data_dbl));
            transport.Send(ref_data_dbl, 5, kTypeDouble, 0, 0);

        } else {
            ASSERT_THAT(recv_data_int, Not(ElementsAreArray(ref_data_int)));
            ASSERT_THAT(recv_data_flt, Not(ElementsAreArray(ref_data_flt)));
            ASSERT_THAT(recv_data_dbl, Not(ElementsAreArray(ref_data_dbl)));
        }

        MPI_Bcast(recv_data_int, 5, MPI_INT, 0, mini_world_comm);
        MPI_Bcast(recv_data_flt, 5, MPI_FLOAT, 0, mini_world_comm);
        MPI_Bcast(recv_data_dbl, 5, MPI_DOUBLE, 0, mini_world_comm);
        ASSERT_THAT(recv_data_int, ElementsAreArray(ref_data_int));
        ASSERT_THAT(recv_data_flt, ElementsAreArray(ref_data_flt));
        ASSERT_THAT(recv_data_dbl, ElementsAreArray(ref_data_dbl));

        MPI_Barrier(world_comm);
    }
}


// Fixture for exercising Initialize() itself. Unlike the base fixture it does
// not write the port file and does not call Initialize(); each test does that
// under its own environment settings. TearDown() still finalizes the transport.
class MpiPortInitializeTest : public MpiPortTransportTest {
  protected:
    void SetUp() override {
        MPI_Barrier(world_comm);
        transport = MpiPortTransport(&mini_world_comm);
    }

    void TearDown() override { transport.Finalize(); }
};

// A negative timeout should produce a warning and fall back to the default,
// while Initialize() still succeeds. Finalization is left to TearDown().
TEST_F(MpiPortInitializeTest, NegativeTimeoutValue) {
    SET_ENV_CONTROL(env_timeout_value, std::to_string(-1).c_str());
    // End the stderr capture before asserting. A failing ASSERT would otherwise
    // leave the capture active, and the next CaptureStderr() aborts the run.
    CaptureStderr();
    const auto status = transport.Initialize(&mini_world_comm);
    std::string error = GetCapturedStderr();
    ASSERT_EQ(status, kSuccess);

    ASSERT_THAT(error, ContainsRegex("~~~ MCL ~ WARNING ~~~"));
    ASSERT_THAT(error, ContainsRegex("The requested environment variable, " + AddQuotes(env_timeout_value.name()) +
                                     " has a negative value!"));
    ASSERT_THAT(error, ContainsRegex("Using the default value instead!"));
    // No explicit Finalize() here: TearDown() already calls it, and calling it
    // twice would finalize the same transport twice.
}

// With a zero timeout, Initialize() should fail with kErrMpiTimeout and report
// the timeout on stderr.
TEST_F(MpiPortInitializeTest, Timeout) {
    SET_ENV_CONTROL(env_timeout_value, std::to_string(0).c_str());
    // End the stderr capture before asserting so that a failure here cannot
    // leave a capture active for the next test.
    CaptureStderr();
    const auto status = transport.Initialize(&mini_world_comm);
    std::string error = GetCapturedStderr();
    ASSERT_EQ(status, kErrMpiTimeout);

    ASSERT_THAT(error, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
    ASSERT_THAT(error, ContainsRegex("Failed to initialize! Initialize in program no."));
    ASSERT_THAT(error, ContainsRegex("timed out!"));
}

// Non-integer values for the timeout and for the program number should each
// make Initialize() fail with kErrInvalidEnvVar and print an error naming the
// offending variable.
TEST_F(MpiPortInitializeTest, InvalidEnvVarValues) {
    // 1) invalid timeout
    SET_ENV_CONTROL(env_timeout_value, "abcd");
    // End each stderr capture before asserting so that a failure cannot
    // leave a capture active (the next CaptureStderr() would abort).
    CaptureStderr();
    auto status = transport.Initialize(&mini_world_comm);
    std::string error = GetCapturedStderr();
    ASSERT_EQ(status, kErrInvalidEnvVar);
    ASSERT_THAT(error, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
    ASSERT_THAT(error, ContainsRegex("The environment variable, " + AddQuotes(env_timeout_value.name()) +
                                     " is present but invalid!"));
    ASSERT_THAT(error, ContainsRegex("Must be an integer."));

    // 2) valid timeout, invalid program number
    // NOTE(review): env_program_number is left set to "abcd" after this test;
    // any test running later in the same process would see it.
    SET_ENV_CONTROL(env_timeout_value, "20");
    SET_ENV_CONTROL(env_program_number, "abcd");
    CaptureStderr();
    status = transport.Initialize(&mini_world_comm);
    error = GetCapturedStderr();
    ASSERT_EQ(status, kErrInvalidEnvVar);
    ASSERT_THAT(error, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
    ASSERT_THAT(error, ContainsRegex("The environment variable, " + AddQuotes(env_program_number.name()) +
                                     " is present but invalid!"));
    ASSERT_THAT(error, ContainsRegex("Must be an integer."));
}


int main(int argc, char** argv) {
    ::testing::InitGoogleTest(&argc, argv);
    MPI_Init(&argc, &argv);

    int world_rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
    TestEventListeners& listeners = UnitTest::GetInstance()->listeners();
    if (world_rank != 0) 
        delete listeners.Release(listeners.default_result_printer());
    
    auto result = RUN_ALL_TESTS();

    MPI_Finalize();
    return result;
}
