/**
 *    Copyright (C) 2019-present MongoDB, Inc.
 *
 *    This program is free software: you can redistribute it and/or modify
 *    it under the terms of the Server Side Public License, version 1,
 *    as published by MongoDB, Inc.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    Server Side Public License for more details.
 *
 *    You should have received a copy of the Server Side Public License
 *    along with this program. If not, see
 *    <http://www.mongodb.com/licensing/server-side-public-license>.
 *
 *    As a special exception, the copyright holders give permission to link the
 *    code of portions of this program with the OpenSSL library under certain
 *    conditions as described in each individual source file and distribute
 *    linked combinations including the program with the OpenSSL library. You
 *    must comply with the Server Side Public License in all respects for
 *    all of the code used other than as permitted herein. If you modify file(s)
 *    with this exception, you may extend this exception to your version of the
 *    file(s), but you are not obligated to do so. If you do not wish to do so,
 *    delete this exception statement from your version. If you delete this
 *    exception statement from all source files in the program, then also delete
 *    it in the license file.
 */

#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest

#include "mongo/platform/basic.h"

#include "mongo/logv2/log.h"
#include "mongo/platform/random.h"
#include "mongo/s/chunk_manager.h"
#include "mongo/s/chunk_writes_tracker.h"
#include "mongo/s/chunks_test_util.h"
#include "mongo/unittest/unittest.h"

namespace mongo {

using chunks_test_util::assertEqualChunkInfo;
using chunks_test_util::calculateCollVersion;
using chunks_test_util::calculateIntermediateShardKey;
using chunks_test_util::genChunkVector;
using chunks_test_util::genRandomSplitPoints;
using chunks_test_util::performRandomChunkOperations;

namespace {

// Source of randomness for the tests in this file. It is seeded from SecureRandom, so different
// runs exercise different random chunk layouts and chunk map bucket sizes.
PseudoRandom _random{SecureRandom().nextInt64()};

// Namespace stamped on every ChunkType built by these tests.
const NamespaceString kNss("TestDB", "TestColl");
// Shard assigned to the chunks that the tests construct by hand.
const ShardId kThisShard("testShard");

ShardVersionMap getShardVersionMap(const ChunkMap& chunkMap) {
    return chunkMap.getShardVersionsMap();
}

std::map<ShardId, ChunkVersion> calculateShardVersions(
    const std::vector<std::shared_ptr<ChunkInfo>>& chunkVector) {
    std::map<ShardId, ChunkVersion> svMap;
    for (const auto& chunk : chunkVector) {
        auto mapIt = svMap.find(chunk->getShardId());
        if (mapIt == svMap.end()) {
            svMap.emplace(chunk->getShardId(), chunk->getLastmod());
            continue;
        }
        if (mapIt->second.isOlderThan(chunk->getLastmod())) {
            mapIt->second = chunk->getLastmod();
        }
    }
    return svMap;
}

std::vector<std::shared_ptr<ChunkInfo>> toChunkInfoPtrVector(
    const std::vector<ChunkType>& chunkTypes, bool initializeWriteTrackerRandom = true) {
    std::vector<std::shared_ptr<ChunkInfo>> chunkPtrs;
    chunkPtrs.reserve(chunkTypes.size());
    for (const auto& chunkType : chunkTypes) {
        auto chunkInfoPtr = std::make_shared<ChunkInfo>(chunkType);
        if (initializeWriteTrackerRandom) {
            chunkInfoPtr->getWritesTracker()->addBytesWritten(_random.nextInt64(30));
        }
        chunkPtrs.push_back(std::move(chunkInfoPtr));
    }
    return chunkPtrs;
}

void validateChunkMap(const ChunkMap& chunkMap,
                      const std::vector<std::shared_ptr<ChunkInfo>>& chunkInfoVector) {

    // The chunkMap should contain all the chunks
    ASSERT_EQ(chunkInfoVector.size(), chunkMap.size());

    // Check collection version
    const auto expectedShardVersions = calculateShardVersions(chunkInfoVector);
    const auto expectedCollVersion = calculateCollVersion(expectedShardVersions);
    ASSERT_EQ(expectedCollVersion, chunkMap.getVersion());

    size_t i = 0;
    chunkMap.forEach([&](const auto& chunkPtr) {
        const auto& expectedChunkPtr = chunkInfoVector[i++];
        // Check that the chunk pointer is valid
        ASSERT(chunkPtr.get() != nullptr);
        assertEqualChunkInfo(*expectedChunkPtr, *chunkPtr);
        return true;
    });

    // Validate all shard versions
    const auto shardVersions = getShardVersionMap(chunkMap);
    ASSERT_EQ(expectedShardVersions.size(), shardVersions.size());
    for (const auto& mapIt : shardVersions) {
        ASSERT_EQ(expectedShardVersions.at(mapIt.first), mapIt.second.shardVersion);
    }

    // Check that vectors are balanced in size
    auto maxVectorSize = static_cast<size_t>(std::lround(chunkMap.getMaxChunkVectorSize() * 1.5));
    auto minVectorSize = std::min(
        chunkMap.size(), static_cast<size_t>(std::lround(chunkMap.getMaxChunkVectorSize() / 2)));

    for (const auto& [maxKeyString, chunkVectorPtr] : chunkMap.getChunkVectorMap()) {
        ASSERT_GTE(chunkVectorPtr->size(), minVectorSize);
        ASSERT_LTE(chunkVectorPtr->size(), maxVectorSize);
    }
}

/**
 * Test fixture exposing the identity of the collection under test (shard key pattern, epoch and
 * timestamp) and helpers to generate random chunk layouts and ChunkMaps for it.
 */
class ChunkMapTest : public unittest::Test {
public:
    // Shard key pattern used for the chunks built by the tests.
    const KeyPattern& getShardKeyPattern() const {
        return _shardKeyPattern;
    }

    // Epoch of the test collection, generated once per fixture instance.
    const OID& collEpoch() const {
        return _epoch;
    }

    // Timestamp of the test collection; this fixture leaves it unset (boost::none).
    const boost::optional<Timestamp>& collTimestamp() const {
        return _collTimestamp;
    }

    /**
     * Builds a ChunkMap containing 'chunks', using a random bucket size that is at least 1 and
     * scales with the number of chunks (up to roughly 1.2x), so that different internal vector
     * layouts get exercised across runs.
     */
    ChunkMap makeChunkMap(const std::vector<std::shared_ptr<ChunkInfo>>& chunks) const {
        const auto bucketSizeUpperBound = static_cast<int64_t>(chunks.size() * 1.2);
        const auto chunkBucketSize =
            static_cast<size_t>(1 + _random.nextInt64(bucketSizeUpperBound));
        LOGV2(7162701, "Creating new chunk map", "chunkBucketSize"_attr = chunkBucketSize);
        return ChunkMap{collEpoch(), collTimestamp(), chunkBucketSize}.createMerged(chunks);
    }

    // Delegates to chunks_test_util::genRandomChunkVector for this collection's namespace and
    // epoch, forwarding the size bounds unchanged.
    std::vector<ChunkType> genRandomChunkVector(size_t maxNumChunks = 30,
                                                size_t minNumChunks = 1) const {
        return chunks_test_util::genRandomChunkVector(kNss, _epoch, maxNumChunks, minNumChunks);
    }

private:
    KeyPattern _shardKeyPattern{chunks_test_util::kShardKeyPattern};
    const OID _epoch{OID::gen()};
    const boost::optional<Timestamp> _collTimestamp;
};

TEST_F(ChunkMapTest, TestAddChunk) {
    // A single chunk spanning the whole shard key space
    const ChunkRange fullRange{getShardKeyPattern().globalMin(),
                               getShardKeyPattern().globalMax()};
    const ChunkVersion version{1, 0, collEpoch(), collTimestamp()};
    const auto chunk =
        std::make_shared<ChunkInfo>(ChunkType{kNss, fullRange, version, kThisShard});

    const auto newChunkMap = makeChunkMap({chunk});
    ASSERT_EQ(newChunkMap.size(), 1);

    validateChunkMap(newChunkMap, {chunk});
}

TEST_F(ChunkMapTest, ConstructChunkMapRandom) {
    // A map built from a random layout must mirror that layout exactly
    const auto chunkInfos = toChunkInfoPtrVector(genRandomChunkVector());
    validateChunkMap(makeChunkMap(chunkInfos), chunkInfos);
}

TEST_F(ChunkMapTest, ConstructChunkMapRandomAllChunksSameVersion) {
    auto chunkTypes = genRandomChunkVector();

    // Stamp every chunk with the first chunk's version. It is copied on purpose, because the
    // first chunk itself gets overwritten by the loop as well.
    const ChunkVersion sharedVersion = chunkTypes.front().getVersion();
    for (auto& chunkType : chunkTypes) {
        chunkType.setVersion(sharedVersion);
    }

    const auto chunkInfos = toChunkInfoPtrVector(chunkTypes);

    // With a single version everywhere, the collection version is that very version
    const auto shardVersions = calculateShardVersions(chunkInfos);
    const auto derivedCollVersion = calculateCollVersion(shardVersions);
    ASSERT_EQ(sharedVersion, derivedCollVersion);

    validateChunkMap(makeChunkMap(chunkInfos), chunkInfos);
}

/*
 * Check that constructing a ChunkMap with chunks that have mismatching epochs fails.
 */
TEST_F(ChunkMapTest, ConstructChunkMapMismatchingEpochs) {
    auto chunkInfos = toChunkInfoPtrVector(genRandomChunkVector());

    // Replace one randomly picked chunk with a copy whose version carries a foreign epoch
    const auto foreignEpoch = OID::gen();
    const auto targetIdx = _random.nextInt32(chunkInfos.size());
    const std::shared_ptr<ChunkInfo> originalChunk = chunkInfos.at(targetIdx);
    const auto& originalVersion = originalChunk->getLastmod();
    const ChunkVersion foreignVersion{originalVersion.majorVersion(),
                                      originalVersion.minorVersion(),
                                      foreignEpoch,
                                      collTimestamp()};
    chunkInfos[targetIdx] = std::make_shared<ChunkInfo>(ChunkType{
        kNss, originalChunk->getRange(), foreignVersion, originalChunk->getShardId()});

    ASSERT_THROWS_CODE(
        makeChunkMap(chunkInfos), AssertionException, ErrorCodes::ConflictingOperationInProgress);
}

/*
 * Check that merging a wide chunk over the tail of a ChunkMap keeps every chunk vector within the
 * balanced size bounds, and that the original map is left unchanged.
 */
TEST_F(ChunkMapTest, UpdateMapNotLeaveSmallVectors) {
    const ChunkVersion initialVersion{1, 0, collEpoch(), collTimestamp()};
    auto chunkVector = toChunkInfoPtrVector(
        genChunkVector(kNss, genRandomSplitPoints(8), initialVersion, 1 /*numShards*/));

    const auto chunkBucketSize = 4;
    LOGV2(7162703, "Constructing new chunk map", "chunkBucketSize"_attr = chunkBucketSize);
    const auto initialChunkMap =
        ChunkMap(collEpoch(), collTimestamp(), chunkBucketSize).createMerged(chunkVector);

    // Check that it contains all the chunks
    ASSERT_EQ(chunkVector.size(), initialChunkMap.size());

    // Merge every chunk from the fifth one up to the last into a single, newer chunk
    auto mergedVersion = initialChunkMap.getVersion();
    mergedVersion.incMinor();

    auto mergedChunk = std::make_shared<ChunkInfo>(ChunkType{
        kNss,
        ChunkRange{chunkVector[4]->getRange().getMin(), chunkVector.back()->getRange().getMax()},
        mergedVersion,
        kThisShard});
    const auto chunkMap = initialChunkMap.createMerged({mergedChunk});

    // The first four chunks survive and the merged chunk replaces all the others
    ASSERT_EQ(5U, chunkMap.size());

    // Check that vectors are balanced in size. Both bounds are size_t so that they compare
    // unsigned-to-unsigned against the vector sizes, as in validateChunkMap().
    const auto maxVectorSize =
        static_cast<size_t>(std::lround(chunkMap.getMaxChunkVectorSize() * 1.5));
    const auto minVectorSize = std::min(
        chunkMap.size(), static_cast<size_t>(std::lround(chunkMap.getMaxChunkVectorSize() / 2)));

    for (const auto& [maxKeyString, chunkVectorPtr] : chunkMap.getChunkVectorMap()) {
        ASSERT_GTE(chunkVectorPtr->size(), minVectorSize);
        ASSERT_LTE(chunkVectorPtr->size(), maxVectorSize);
    }

    // Check original map is still valid
    validateChunkMap(initialChunkMap, chunkVector);
}


/*
 * Check that merging into a ChunkMap a chunk whose epoch differs from the collection's fails.
 */
TEST_F(ChunkMapTest, UpdateChunkMapMismatchingEpochs) {
    auto chunkInfos = toChunkInfoPtrVector(genRandomChunkVector());

    auto chunkMap = makeChunkMap(chunkInfos);
    const auto collVersion = chunkMap.getVersion();

    // Build an update over a random existing range that reuses the collection version numbers
    // but carries a brand new epoch
    const auto foreignEpoch = OID::gen();
    const auto targetIdx = _random.nextInt32(chunkInfos.size());
    const auto& targetChunk = chunkInfos.at(targetIdx);
    const ChunkVersion foreignVersion{
        collVersion.majorVersion(), collVersion.minorVersion(), foreignEpoch, collTimestamp()};
    const auto updateChunk = std::make_shared<ChunkInfo>(
        ChunkType{kNss, targetChunk->getRange(), foreignVersion, targetChunk->getShardId()});

    ASSERT_THROWS_CODE(chunkMap.createMerged({updateChunk}),
                       AssertionException,
                       ErrorCodes::ConflictingOperationInProgress);
}

/*
 * Test update of ChunkMap with random chunk manipulation (splits/merges/moves).
 */
TEST_F(ChunkMapTest, UpdateChunkMapRandom) {
    auto initialChunks = genRandomChunkVector();
    auto initialChunksInfo = toChunkInfoPtrVector(initialChunks);

    const auto initialChunkMap = makeChunkMap(initialChunksInfo);

    const auto initialShardVersions = calculateShardVersions(initialChunksInfo);
    const auto initialCollVersion = calculateCollVersion(initialShardVersions);

    // Apply a random sequence of chunk operations to a copy of the initial layout
    auto chunks = initialChunks;

    const auto maxNumChunkOps = 2 * initialChunks.size();
    const auto numChunkOps = _random.nextInt32(maxNumChunkOps);
    performRandomChunkOperations(&chunks, numChunkOps);

    // Expected layout after the operations. It must be built from the modified 'chunks' and not
    // from 'initialChunks'. Otherwise no chunk would be newer than the initial collection version,
    // the update below would always be empty, and the random operations would never be tested.
    auto chunksInfo = toChunkInfoPtrVector(chunks, false /* initializeWriteTrackerRandom */);

    std::vector<std::shared_ptr<ChunkInfo>> updatedChunksInfo;
    for (auto& chunkPtr : chunksInfo) {
        // First chunk of the initial layout overlapping this one, i.e. the first initial chunk
        // whose max key is strictly greater than this chunk's min key
        const auto& overlapInitChunk =
            **std::lower_bound(initialChunksInfo.begin(),
                               initialChunksInfo.end(),
                               ShardKeyPattern::toKeyString(chunkPtr->getRange().getMin()),
                               [](const auto& chunkInfo, const std::string& shardKeyString) {
                                   return chunkInfo->getMaxKeyString() <= shardKeyString;
                               });
        // Each new chunk inherits the written bytes from the first overlapping old chunk
        chunkPtr->getWritesTracker()->addBytesWritten(
            overlapInitChunk.getWritesTracker()->getBytesWritten());

        // Only chunks touched by the random operations (newer than the initial collection
        // version) need to be merged into the map
        if (!chunkPtr->getLastmod().isOlderOrEqualThan(initialCollVersion)) {
            updatedChunksInfo.push_back(std::make_shared<ChunkInfo>(ChunkType{
                kNss, chunkPtr->getRange(), chunkPtr->getLastmod(), chunkPtr->getShardId()}));
        }
    }

    // Create updated chunk map and validate it
    auto chunkMap = initialChunkMap.createMerged(updatedChunksInfo);
    validateChunkMap(chunkMap, chunksInfo);

    // Check that the initialChunkMap is still valid and usable
    validateChunkMap(initialChunkMap, initialChunksInfo);
}

TEST_F(ChunkMapTest, TestEnumerateAllChunks) {
    const ChunkVersion version{1, 0, collEpoch(), collTimestamp()};
    const auto& keyPattern = getShardKeyPattern();

    // Three chunks covering the whole key space: [MinKey, 0), [0, 100), [100, MaxKey)
    const std::vector<std::shared_ptr<ChunkInfo>> chunks{
        std::make_shared<ChunkInfo>(ChunkType{
            kNss, ChunkRange{keyPattern.globalMin(), BSON("a" << 0)}, version, kThisShard}),
        std::make_shared<ChunkInfo>(
            ChunkType{kNss, ChunkRange{BSON("a" << 0), BSON("a" << 100)}, version, kThisShard}),
        std::make_shared<ChunkInfo>(ChunkType{
            kNss, ChunkRange{BSON("a" << 100), keyPattern.globalMax()}, version, kThisShard})};
    auto newChunkMap = makeChunkMap(chunks);

    // Every chunk must be visited once, in strictly increasing order of its max key
    int visitedChunks = 0;
    auto previousMax = keyPattern.globalMin();
    newChunkMap.forEach([&](const auto& chunkInfo) {
        ASSERT(SimpleBSONObjComparator::kInstance.evaluate(chunkInfo->getMax() > previousMax));
        previousMax = chunkInfo->getMax();
        ++visitedChunks;
        return true;
    });

    ASSERT_EQ(visitedChunks, newChunkMap.size());
}


TEST_F(ChunkMapTest, TestIntersectingChunk) {
    ChunkVersion version{1, 0, collEpoch(), collTimestamp()};

    // Chunks: [MinKey, 0), [0, 100), [100, MaxKey)
    auto newChunkMap =
        makeChunkMap({std::make_shared<ChunkInfo>(
                          ChunkType{kNss,
                                    ChunkRange{getShardKeyPattern().globalMin(), BSON("a" << 0)},
                                    version,
                                    kThisShard}),

                      std::make_shared<ChunkInfo>(ChunkType{
                          kNss, ChunkRange{BSON("a" << 0), BSON("a" << 100)}, version, kThisShard}),

                      std::make_shared<ChunkInfo>(
                          ChunkType{kNss,
                                    ChunkRange{BSON("a" << 100), getShardKeyPattern().globalMax()},
                                    version,
                                    kThisShard})});

    // A key strictly inside [0, 100) resolves to that chunk
    auto intersectingChunk = newChunkMap.findIntersectingChunk(BSON("a" << 50));

    ASSERT(intersectingChunk);
    ASSERT(
        SimpleBSONObjComparator::kInstance.evaluate(intersectingChunk->getMin() == BSON("a" << 0)));
    ASSERT(SimpleBSONObjComparator::kInstance.evaluate(intersectingChunk->getMax() ==
                                                       BSON("a" << 100)));

    // A key sorting above every finite split point resolves to the last chunk [100, MaxKey).
    // NOTE(review): BSON("a" << globalMax()) builds {a: {a: MaxKey}} (an embedded document, which
    // sorts after all numbers) rather than {a: MaxKey}; confirm this is the intended probe key.
    intersectingChunk =
        newChunkMap.findIntersectingChunk(BSON("a" << getShardKeyPattern().globalMax()));
    // Fail the test cleanly instead of dereferencing a null pointer if no chunk was found
    ASSERT(intersectingChunk);
    ASSERT(SimpleBSONObjComparator::kInstance.evaluate(intersectingChunk->getMin() ==
                                                       BSON("a" << 100)));
    ASSERT(SimpleBSONObjComparator::kInstance.evaluate(intersectingChunk->getMax() ==
                                                       getShardKeyPattern().globalMax()));
}

TEST_F(ChunkMapTest, TestIntersectingChunkRandom) {
    auto chunks = toChunkInfoPtrVector(genRandomChunkVector());

    const auto chunkMap = makeChunkMap(chunks);

    // Pick a random chunk and an intermediate shard key between its bounds
    auto targetChunkIt = chunks.begin() + _random.nextInt64(chunks.size());
    auto intermediateKey = calculateIntermediateShardKey(
        (*targetChunkIt)->getMin(), (*targetChunkIt)->getMax(), 0.2 /* minKeyProb */);

    // The map must resolve that key to exactly the chosen chunk. Check for null first so that a
    // failed lookup fails the test instead of dereferencing an empty pointer.
    auto intersectingChunkPtr = chunkMap.findIntersectingChunk(intermediateKey);
    ASSERT(intersectingChunkPtr);
    assertEqualChunkInfo(**targetChunkIt, *intersectingChunkPtr);
}

TEST_F(ChunkMapTest, TestEnumerateOverlappingChunks) {
    const ChunkVersion version{1, 0, collEpoch(), collTimestamp()};
    const auto& keyPattern = getShardKeyPattern();

    // Chunks: [MinKey, 0), [0, 100), [100, MaxKey)
    const std::vector<std::shared_ptr<ChunkInfo>> chunks{
        std::make_shared<ChunkInfo>(ChunkType{
            kNss, ChunkRange{keyPattern.globalMin(), BSON("a" << 0)}, version, kThisShard}),
        std::make_shared<ChunkInfo>(
            ChunkType{kNss, ChunkRange{BSON("a" << 0), BSON("a" << 100)}, version, kThisShard}),
        std::make_shared<ChunkInfo>(ChunkType{
            kNss, ChunkRange{BSON("a" << 100), keyPattern.globalMax()}, version, kThisShard})};
    const auto newChunkMap = makeChunkMap(chunks);

    // Number of chunks visited by forEachOverlappingChunk for the given bounds
    const auto countOverlapping =
        [&](const BSONObj& lowerBound, const BSONObj& upperBound, bool isMaxInclusive) {
            int visited = 0;
            newChunkMap.forEachOverlappingChunk(
                lowerBound, upperBound, isMaxInclusive, [&](const auto&) {
                    visited++;
                    return true;
                });
            return visited;
        };

    ASSERT_EQ(countOverlapping(BSON("a" << -50), BSON("a" << 150), true), 3);
    ASSERT_EQ(countOverlapping(BSON("a" << -50), BSON("a" << keyPattern.globalMax()), false), 3);
    ASSERT_EQ(countOverlapping(BSON("a" << 50), BSON("a" << 100), true), 2);
    ASSERT_EQ(countOverlapping(BSON("a" << 50), BSON("a" << 100), false), 1);
}

TEST_F(ChunkMapTest, ForEachNoShardKey) {
    const auto chunks = toChunkInfoPtrVector(genRandomChunkVector());
    const auto chunkMap = makeChunkMap(chunks);

    // Stop the iteration after a random number of chunks (at least one)
    const int64_t stopAfter = std::max(_random.nextInt64(chunks.size()), static_cast<int64_t>(1));

    int visited = 0;
    chunkMap.forEach([&](const auto& chunkInfo) {
        // Without a starting key the iteration begins at the very first chunk
        assertEqualChunkInfo(*chunks[visited], *chunkInfo);
        ++visited;
        return visited < stopAfter;
    });

    ASSERT_EQ(visited, stopAfter);
}

TEST_F(ChunkMapTest, ForEachWithShardKey) {
    const auto chunks = toChunkInfoPtrVector(genRandomChunkVector());
    const auto chunkMap = makeChunkMap(chunks);

    // Start iterating from a key taken from inside a randomly chosen chunk
    const auto startIdx = static_cast<size_t>(_random.nextInt64(chunks.size()));
    const auto& startChunk = chunks[startIdx];
    const auto startKey = calculateIntermediateShardKey(
        startChunk->getMin(), startChunk->getMax(), 0.2 /* minKeyProb */);

    // Stop after a random number of chunks (at least one, never past the end)
    const auto remainingChunks = chunks.size() - startIdx;
    const auto stopIdx =
        startIdx + std::max(_random.nextInt64(remainingChunks), static_cast<int64_t>(1));

    size_t current = startIdx;
    chunkMap.forEach(
        [&](const auto& chunkInfo) {
            assertEqualChunkInfo(*chunks[current], *chunkInfo);
            ++current;
            return current < stopIdx;
        },
        startKey);

    ASSERT_EQ(current, stopIdx);
}

TEST_F(ChunkMapTest, TestEnumerateOverlappingChunksRandom) {
    const auto chunks = toChunkInfoPtrVector(genRandomChunkVector());
    const auto chunkMap = makeChunkMap(chunks);

    // Pick a random window [firstIdx, lastIdx] of consecutive chunks
    const auto firstIdx = _random.nextInt64(chunks.size());
    const auto lastIdx = firstIdx + _random.nextInt64(chunks.size() - firstIdx);

    // The bounds are keys taken from inside the first and the last chunk of the window
    const auto& firstChunk = chunks[firstIdx];
    const auto& lastChunk = chunks[lastIdx];
    const auto minBound = calculateIntermediateShardKey(
        firstChunk->getMin(), firstChunk->getMax(), 0.2 /* minKeyProb */);
    const auto maxBound = calculateIntermediateShardKey(
        lastChunk->getMin(), lastChunk->getMax(), 0.2 /* minKeyProb */);

    // With an inclusive max bound, exactly the chunks of the window are visited, in order
    auto expectedIdx = firstIdx;
    chunkMap.forEachOverlappingChunk(minBound, maxBound, true, [&](const auto& chunkInfoPtr) {
        assertEqualChunkInfo(*chunks[expectedIdx], *chunkInfoPtr);
        ++expectedIdx;
        return true;
    });
    ASSERT_EQ(lastIdx + 1, expectedIdx);
}

}  // namespace

}  // namespace mongo
