//    MCL: MiMiC Communication Library
//    Copyright (C) 2015-2025  The MiMiC Authors (see CONTRIBUTORS file for details).
//
//    This file is part of MCL.
//
//    MCL is free software: you can redistribute it and/or modify
//    it under the terms of the GNU Lesser General Public License as
//    published by the Free Software Foundation, either version 3 of
//    the License, or (at your option) any later version.
//
//    MCL is distributed in the hope that it will be useful, but
//    WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

#include "stub_transport.h"

#include <filesystem>
#include <memory>
#include <string>
#include <system_error>
#include <vector>

#include "mpi.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "data_types.h"
#include "env_controls.h"
#include "error_codes.h"
#include "test_macros.h"
#include "transfer_logger.h"

using namespace mcl;
using ::testing::ContainsRegex;
using ::testing::ElementsAreArray;
using ::testing::TestEventListeners;
using ::testing::UnitTest;
using ::testing::internal::CaptureStderr;
using ::testing::internal::GetCapturedStderr;

class StubTransportTest : public ::testing::Test {
  protected:
    int world_rank;
    StubTransport *transport;
    std::string file_name = "MCL_stub_log";
    std::vector<int> ref_ints = {5, 2, 1, 4, 9, 10, 23};
    int tag = 1;
    int id = 10;
    int client_id = 5;
    int server_id = 0;
    int size = static_cast<int>(ref_ints.size());

    void SetUp() override {
        MPI_Barrier(MPI_COMM_WORLD);
        MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
        if (world_rank == 0) {
            if (std::filesystem::exists(file_name))
                std::remove(file_name.c_str());

            TransferLogger *logger = new TransferLogger(file_name, WRITE);
            int type = kTypeInt;
            bool skipped_send = false;

            ASSERT_EQ(logger->Write(&skipped_send, 1, kTypeChar, kSkipSendTag, kSkipSendTag), kSuccess);
            ASSERT_EQ(logger->Write(ref_ints.data(), size, type, tag, id), kSuccess);
            ASSERT_EQ(logger->Write(reinterpret_cast<void*>(&size), 1, type, kProbeTag, id), kSuccess);
            delete logger;
        }

        transport = new StubTransport();
        
        int client_id = 5;
        SET_ENV_CONTROL(env_program_number, std::to_string(client_id).c_str());
        SET_ENV_CONTROL(env_test_data_file, file_name.c_str());
    }

    void TearDown() override {
        MPI_Barrier(MPI_COMM_WORLD);
        if (world_rank == 0) std::remove(file_name.c_str());
        delete transport; 
    }
    
};


/// With the server program number set, Initialize() succeeds on every rank.
/// It reports a single program with the server id; only rank 0 claims to be
/// the server, and no rank claims to be a client.
TEST_F(StubTransportTest, InitializeServer) {
    SET_ENV_CONTROL(env_program_number, std::to_string(server_id).c_str());
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    ASSERT_EQ(transport->num_programs(), 1);
    ASSERT_EQ(transport->program_id(), server_id);
    // ASSERT_EQ/ASSERT_FALSE (instead of ASSERT_TRUE(a == b) / ASSERT_TRUE(!x))
    // print the actual values on failure.
    ASSERT_EQ(transport->is_server(), world_rank == 0);
    ASSERT_FALSE(transport->is_client());
}

/// With the fixture's default client program number, Initialize() succeeds
/// on every rank and reports one program with the client id. Only rank 0
/// claims to be the client, and no rank claims to be the server.
TEST_F(StubTransportTest, InitializeClient) {
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    ASSERT_EQ(transport->program_id(), client_id);
    ASSERT_EQ(transport->num_programs(), 1);
    // Value-printing assertions instead of ASSERT_TRUE on a boolean expression.
    ASSERT_FALSE(transport->is_server());
    ASSERT_EQ(transport->is_client(), world_rank == 0);
}

/// Pointing the transport at a non-existent log must fail on rank 0 with
/// kErrFileMissing and an MCL error on stderr. All other ranks still
/// initialize successfully.
TEST_F(StubTransportTest, InitializeNoFile) {
    SET_ENV_CONTROL(env_test_data_file, "no_file");

    // Capture is closed before any assertion can return early.
    CaptureStderr();
    const int status = transport->Initialize(nullptr);
    const std::string captured = GetCapturedStderr();

    if (world_rank != 0) {
        ASSERT_EQ(status, kSuccess);
        return;
    }

    ASSERT_EQ(status, kErrFileMissing);
    ASSERT_THAT(captured, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
    ASSERT_THAT(captured, ContainsRegex("File \"no_file\" doesn't exist!"));
}

/// After a successful Initialize(), the client rank can Finalize() cleanly.
/// Non-client ranks have nothing to check.
TEST_F(StubTransportTest, FinalizeClient) {
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    if (!transport->is_client()) return;

    ASSERT_EQ(transport->Finalize(), kSuccess);
}

/// A Send() whose arguments match the logged payload record succeeds on the
/// client rank.
TEST_F(StubTransportTest, Send) {
    // Checked so a failed Initialize() cannot turn the client branch into a
    // silent no-op pass.
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    if (transport->is_client()) {
        ASSERT_EQ(transport->Send(ref_ints.data(), size, kTypeInt, tag, id), kSuccess);
    }
}

/// A Send() with an id that differs from the logged record fails with
/// kErrLogMismatch and prints an MCL mismatch error to stderr.
TEST_F(StubTransportTest, SendMismatch) {
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    if (transport->is_client()) {
        // Close the capture before asserting. A failed ASSERT returns early,
        // and a capturer left active makes the next CaptureStderr() abort.
        CaptureStderr();
        const auto result = transport->Send(ref_ints.data(), size, kTypeInt, tag, id - 1);
        const std::string error = GetCapturedStderr();

        ASSERT_EQ(result, kErrLogMismatch);
        ASSERT_THAT(error, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
        ASSERT_THAT(error, ContainsRegex("Mismatch between the call and logged data!"));
    }
}

/// A Receive() matching the logged payload record succeeds on the client
/// rank and fills the buffer with the reference integers.
TEST_F(StubTransportTest, Receive) {
    // Checked so a failed Initialize() cannot make the test pass vacuously.
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    if (transport->is_client()) {
        std::vector<int> ints(size);  // `size` zero-initialised elements
        ASSERT_EQ(transport->Receive(ints.data(), size, kTypeInt, tag, id), kSuccess);
        ASSERT_THAT(ints, ElementsAreArray(ref_ints));
    }
}

/// A Receive() with an id that differs from the logged record fails with
/// kErrLogMismatch and prints an MCL mismatch error to stderr.
TEST_F(StubTransportTest, ReceiveMismatch) {
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    if (transport->is_client()) {
        std::vector<int> ints(size);
        // Release the capture before any ASSERT can return early; otherwise
        // the next test's CaptureStderr() hits gtest's single-capturer check.
        CaptureStderr();
        const auto result = transport->Receive(ints.data(), size, kTypeInt, tag, id - 1);
        const std::string error = GetCapturedStderr();

        ASSERT_EQ(result, kErrLogMismatch);
        ASSERT_THAT(error, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
        ASSERT_THAT(error, ContainsRegex("Mismatch between the call and logged data!"));
    }
}

/// After the payload record has been consumed by Receive(), ProbeSize()
/// replays the logged probe record and returns the payload length.
TEST_F(StubTransportTest, ProbeSize) {
    // Checked so a failed Initialize() cannot make the test pass vacuously.
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    if (transport->is_client()) {
        std::vector<int> ints(size);
        ASSERT_EQ(transport->Receive(ints.data(), size, kTypeInt, tag, id), kSuccess);
        ASSERT_EQ(transport->ProbeSize(kTypeInt, id), size);
    }
}

/// Calling ProbeSize() while the next logged record is still the payload is
/// out of order. It must fail with kErrLogMismatch and print an MCL mismatch
/// error.
TEST_F(StubTransportTest, ProbeSizeMismatch) {
    ASSERT_EQ(transport->Initialize(nullptr), kSuccess);
    if (transport->is_client()) {
        // Finish capturing before asserting, so a failure cannot leave the
        // stderr capturer active for later tests.
        CaptureStderr();
        const auto result = transport->ProbeSize(kTypeInt, id);
        const std::string error = GetCapturedStderr();

        ASSERT_EQ(result, kErrLogMismatch);
        ASSERT_THAT(error, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
        ASSERT_THAT(error, ContainsRegex("Mismatch between the call and logged data!"));
    }
}


int main(int argc, char** argv) {
    ::testing::InitGoogleTest(&argc, argv);
    MPI_Init(&argc, &argv);

    int world_rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
    TestEventListeners& listeners = UnitTest::GetInstance()->listeners();
    if (world_rank != 0)
        delete listeners.Release(listeners.default_result_printer());

    auto result = RUN_ALL_TESTS();
    
    MPI_Finalize();
    return result;
}
