SplitDimensionM: heuristic update
v-Golubev committed Dec 23, 2024
1 parent 9ff5942 commit 78a43ee
Showing 4 changed files with 61 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/common/snippets/docs/mha_optimization_guide.md
@@ -65,7 +65,7 @@ The Transpose orders supported by the decomposition are defined by `TokenizeMHASnippets`

[SplitDimensionM](../src/pass/split_dimension_m.cpp) splits the M dimension of MHA into 2 parts (`batch_m` and `new_m`) by inserting a Reshape on the A input of the first Matmul and on the output of the second Matmul (the rest of the Subgraph's inputs are reshaped by Unsqueeze-like reshapes in order not to break the subgraph semantics).
This optimization increases the parallel work amount by `batch_m` times, thus enabling more efficient parallel execution in some cases.
The splitting is performed based on a heuristic algorithm which can be found in the `SplitDimensionM::get_splited_dimensions` method.
The splitting is performed based on a heuristic algorithm which can be found in the `SplitDimensionM::split` method.
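
As a rough illustration of the effect (the numbers are taken from the pass's unit tests, not from the guide's figure): with `batch_dim = 5`, `M = 16384` and 40 threads, M can be split into `batch_m = 8` and `new_m = 2048`, so the parallel work amount grows from 5 to 5 × 8 = 40 and each thread executes the kernel exactly once.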

Let's consider an example of the transformation:

13 changes: 12 additions & 1 deletion src/common/snippets/include/snippets/pass/split_dimension_m.hpp
@@ -67,7 +67,18 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass {

private:
static std::shared_ptr<ov::op::v0::MatMul> get_matmul(const std::shared_ptr<op::Subgraph>& subgraph);
static std::pair<size_t, size_t> get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Contains splitM approaches that make the batch ideally divisible by optimal_parallelism_work_amount
*/
static std::pair<size_t, size_t> compute_ideal_cases_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Aggressively splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration.
*/
static std::pair<size_t, size_t> compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Conservatively splits m_dim to keep the batch within the (optimal_parallelism_work_amount, 2 * optimal_parallelism_work_amount) interval
*/
static std::pair<size_t, size_t> compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);

void reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim);

75 changes: 46 additions & 29 deletions src/common/snippets/src/pass/split_dimension_m.cpp
@@ -4,8 +4,8 @@

#include "snippets/pass/split_dimension_m.hpp"

#include "snippets/utils/utils.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils/utils.hpp"

namespace {
size_t get_dim_M(const ov::Shape& shape) {
@@ -31,45 +31,55 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>&
return matmul && !matmul->get_transpose_a() && !matmul->is_dynamic();
}

std::pair<size_t, size_t> SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };

std::pair<size_t, size_t> SplitDimensionM::compute_ideal_cases_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
// Ideal case #1: M can be split into parts, one of which complements the batch dimension up to the optimal parallel work amount
// In this case, each thread will execute the Snippets kernel once
const size_t lower_bound = optimal_parallelism_work_amount / batch_dim;
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) {
splited.first = lower_bound;
splited.second = m_dim / lower_bound;
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0)
return std::make_pair(lower_bound, m_dim / lower_bound);

// Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough
// In this case, each thread will execute the Snippets kernel 'batch_dim' times
if (m_dim % optimal_parallelism_work_amount == 0) {
const auto new_m_dim = m_dim / optimal_parallelism_work_amount;
const size_t min_kernel_m = 64;
if (new_m_dim >= min_kernel_m) {
splited.first = optimal_parallelism_work_amount;
splited.second = new_m_dim;
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
if (new_m_dim >= min_kernel_m)
return std::make_pair(optimal_parallelism_work_amount, new_m_dim);
}

return std::make_pair(1, m_dim);
}
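
A worked trace of the two ideal cases on values from the unit tests (illustrative only): for `batch_dim = 5`, `m_dim = 16384`, `optimal_parallelism_work_amount = 40`, case #1 applies: `lower_bound = 40 / 5 = 8`, `8 * 5 == 40` and `16384 % 8 == 0`, giving `(8, 2048)`. For `batch_dim = 5`, `m_dim = 16384`, `optimal_parallelism_work_amount = 32`, case #1 fails (`32 / 5 = 6`, `6 * 5 != 32`) but case #2 applies: `16384 % 32 == 0` and `new_m_dim = 512 >= 64`, giving `(32, 512)`.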

std::pair<size_t, size_t> SplitDimensionM::compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };
const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim);
for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) {
size_t divisor_1 = m_dim / divisor_0;
if (divisor_1 * divisor_0 == m_dim) {
splited.first = divisor_0;
splited.second = divisor_1;
break;
}
if (divisor_1 * divisor_0 == m_dim)
return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair(divisor_0, divisor_1) : splited;
}
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
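
For example (one of the existing test vectors): for `batch_dim = 25`, `m_dim = 50`, `optimal_parallelism_work_amount = 40`, `upper_bound = div_up(80, 25) = 4`, so divisors 3 and 2 are tried; 3 does not divide 50, 2 does, and `2 * 25 = 50 >= 40`, so the result is `(2, 25)`.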

std::pair<size_t, size_t> SplitDimensionM::compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
constexpr size_t min_kernel_m = 32;
std::pair<size_t, size_t> best_result = {1, m_dim};
for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) {
if (m_dim % divisor != 0)
continue;
if (divisor >= min_kernel_m)
return std::make_pair(m_dim / divisor, divisor);
const size_t m_kernel = m_dim / divisor;
if (m_kernel >= min_kernel_m) {
best_result.first = divisor;
best_result.second = m_kernel;
}
}
if (best_result.first * batch_dim >= optimal_parallelism_work_amount)
return best_result;
return std::make_pair(1, m_dim);
}
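
For example (one of the newly added test vectors): for `batch_dim = 48`, `m_dim = 6600`, `optimal_parallelism_work_amount = 32`, the loop walks the divisors 2, 3, ... of 6600; the first divisor that is itself >= `min_kernel_m` is 33, so the function returns `(6600 / 33, 33) = (200, 33)`, keeping `kernel_m` small while `200 * 48` comfortably exceeds the work amount.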

bool SplitDimensionM::can_be_optimized(const std::shared_ptr<const ov::Node>& node, size_t concurrency) {
if (!is_supported_matmul(node))
return false;
@@ -131,16 +141,23 @@ bool SplitDimensionM::split(const ov::Shape& shape, size_t optimal_parallelism_w
if (is_prime_number(m_dim))
return false;

auto is_optimized = [&](size_t batch_dim) {
return batch_dim >= optimal_parallelism_work_amount;
};

// We skip optimization if the current batch is optimal for concurrency
if (is_optimized(batch_dim))
if (batch_dim % optimal_parallelism_work_amount == 0)
return false;

std::tie(batch_m_dim, new_m_dim) = get_splited_dimensions(batch_dim, m_dim, optimal_parallelism_work_amount);
return is_optimized(batch_dim * batch_m_dim);
std::tie(batch_m_dim, new_m_dim) = compute_ideal_cases_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount);
if (batch_m_dim != 1)
return true;

// If the M dimension is big enough, the aggressive heuristic is used to minimize kernel_m.
// For a smaller M dimension, the conservative heuristic is used to preserve the old behaviour.
const bool big_m_dim = m_dim >= 4000;
if (big_m_dim) {
std::tie(batch_m_dim, new_m_dim) = compute_aggressive_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount);
} else if (batch_dim < optimal_parallelism_work_amount) {
std::tie(batch_m_dim, new_m_dim) = compute_conservative_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount);
}
return batch_m_dim != 1;
}
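
A worked trace of the whole dispatch on one of the newly added test vectors: for `batch_dim = 48`, `m_dim = 4097`, `optimal_parallelism_work_amount = 32`, 4097 = 17 × 241 is not prime, `48 % 32 != 0`, both ideal cases fail (`32 / 48 = 0` and `4097 % 32 == 1`), and since `4097 >= 4000` the aggressive heuristic runs and finds the only small divisor 17, giving `(batch_m_dim, new_m_dim) = (17, 241)`.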

void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim) {
2 changes: 2 additions & 0 deletions src/common/snippets/tests/src/utils/split_dim_m.cpp
@@ -59,6 +59,8 @@ const std::vector<SplitDimensionMParams> split_dimension_cases = {
{InputData{25, 50, 40}, ReferenceData{true, 2, 25}},
{InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}},
{InputData{5, 16384, 32}, ReferenceData{true, 32, 512}},
{InputData{48, 4097, 32}, ReferenceData{true, 17, 241}},
{InputData{48, 6600, 32}, ReferenceData{true, 200, 33}},
};
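
For quick offline experimentation, a minimal standalone sketch of the same heuristic chain (a reimplementation for illustration, not the class itself), assuming `InputData` is `{batch_dim, m_dim, concurrency}` and `ReferenceData` is `{expected_result, batch_m_dim, new_m_dim}`:

```cpp
// Standalone sketch of the SplitDimensionM heuristic chain, for experimentation only.
// Mirrors the dispatch order in split(): ideal cases -> aggressive (M >= 4000) -> conservative.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <utility>

namespace {
size_t div_up(size_t a, size_t b) { return (a + b - 1) / b; }

std::pair<size_t, size_t> ideal_cases(size_t batch, size_t m, size_t work) {
    const size_t lower_bound = work / batch;
    if (lower_bound * batch == work && m % lower_bound == 0)
        return {lower_bound, m / lower_bound};
    if (m % work == 0 && m / work >= 64)  // min_kernel_m = 64 in the ideal case
        return {work, m / work};
    return {1, m};
}

std::pair<size_t, size_t> aggressive(size_t batch, size_t m, size_t work) {
    constexpr size_t min_kernel_m = 32;
    std::pair<size_t, size_t> best = {1, m};
    for (size_t d = 2; d < std::sqrt(m); ++d) {
        if (m % d != 0) continue;
        if (d >= min_kernel_m) return {m / d, d};
        if (m / d >= min_kernel_m) best = {d, m / d};
    }
    return best.first * batch >= work ? best : std::pair<size_t, size_t>{1, m};
}

std::pair<size_t, size_t> conservative(size_t batch, size_t m, size_t work) {
    const size_t upper_bound = div_up(2 * work, batch);
    for (size_t d0 = upper_bound - 1; d0 > 1; d0--) {
        const size_t d1 = m / d0;
        if (d1 * d0 == m)
            return d0 * batch >= work ? std::pair<size_t, size_t>{d0, d1}
                                      : std::pair<size_t, size_t>{1, m};
    }
    return {1, m};
}

bool is_prime(size_t n) {
    if (n < 2) return false;
    for (size_t d = 2; d * d <= n; ++d)
        if (n % d == 0) return false;
    return true;
}

std::pair<size_t, size_t> split_m(size_t batch, size_t m, size_t work) {
    if (is_prime(m) || batch % work == 0) return {1, m};
    auto result = ideal_cases(batch, m, work);
    if (result.first != 1) return result;
    if (m >= 4000) return aggressive(batch, m, work);
    if (batch < work) return conservative(batch, m, work);
    return {1, m};
}
}  // namespace

int main() {
    // The test vectors listed above: {batch, M, concurrency} -> {batch_m, new_m}.
    assert((split_m(25, 50, 40) == std::pair<size_t, size_t>{2, 25}));
    assert((split_m(5, 16384, 40) == std::pair<size_t, size_t>{8, 2048}));
    assert((split_m(5, 16384, 32) == std::pair<size_t, size_t>{32, 512}));
    assert((split_m(48, 4097, 32) == std::pair<size_t, size_t>{17, 241}));
    assert((split_m(48, 6600, 32) == std::pair<size_t, size_t>{200, 33}));
    std::cout << "all splits match\n";
}
```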

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM,
