[Model] Group support for int8/int4 models #461

Draft · wants to merge 3 commits into base: main
10 changes: 7 additions & 3 deletions src/common/transformer_ctx.h
@@ -86,6 +86,9 @@ struct DecoderContext {
float attFactor;
float epsilon;

+ // quantization configuration
+ int groupsize;
+
// rope scaling parameters
RopeParams *ropeParamsPtr;

@@ -132,7 +135,7 @@ struct DecoderContext {
DecoderContext(int _layers, int _hiddenSize, int _headSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act,
float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength,
int _splitIdx, int _splits, MMHelper *mmHelper, void *device = nullptr, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr,
- bool _useLogN = true, bool _useNTK = true, int numThreads = 0)
+ bool _useLogN = true, bool _useNTK = true, int numThreads = 0, int _groupsize = -1)
: layers(_layers)
, hiddenSize(_hiddenSize)
, attHeadSize(_headSize)
@@ -153,7 +156,8 @@ struct DecoderContext {
, ppRank(_ppRank)
, tpSize(_splits)
, tpRank(_splitIdx)
- , epsilon(epsilon) {
+ , epsilon(epsilon)
+ , groupsize(_groupsize) {
if (attHeadNum != 0) {
this->attFactor = 1 / sqrtf(attHeadSize);
}
@@ -325,4 +329,4 @@ struct DecoderContext {
if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device);
#endif
}
- };
+ };
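The new `groupsize` field drives group-wise quantization: `-1` keeps the previous behaviour (a single group spanning the whole input dimension, i.e. per-output-channel scales), while a positive value splits the input dimension into `hiddenSize / groupsize` groups, each with its own scale and zero point per output column (see the `groups` computation in `attention.h` below). A minimal sketch of how such a lookup could be organised, assuming an affine convention `w ≈ (q - zero) * scale`; the struct and function names are illustrative, not part of this PR:

```cpp
// Illustrative sketch only (not code from this PR): per-group scale/zero lookup
// under the convention visible in this diff, where groupsize == -1 collapses to
// a single group.
#include <cstdint>

struct GroupQuantView {
    const float *scales;  // [groups x cols], row-major: one row per group
    const float *zeros;   // same layout as scales
    int cols;             // number of output channels
    int groupsize;        // input rows per group, or -1 for a single group

    int groupOf(int row) const { return groupsize == -1 ? 0 : row / groupsize; }

    float scaleAt(int row, int col) const { return scales[groupOf(row) * cols + col]; }
    float zeroAt(int row, int col) const { return zeros[groupOf(row) * cols + col]; }
};

// Dequantize one int8 weight element, assuming w ~= (q - zero) * scale.
inline float dequant(int8_t q, const GroupQuantView &v, int row, int col) {
    return (static_cast<float>(q) - v.zeroAt(row, col)) * v.scaleAt(row, col);
}
```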
56 changes: 30 additions & 26 deletions src/layers/attention.h
@@ -123,23 +123,27 @@ class Attention {
float *concatScale = nullptr;
float *concatZero = nullptr;
if constexpr (std::is_same_v<OriWeiT, int8_t> || std::is_same_v<OriWeiT, uint4x2_t>) {
- concatScale = (float *)malloc(responsibleCols * sizeof(float));
- concatZero = (float *)malloc(responsibleCols * sizeof(float));
- memcpy(concatScale, queryScale + this->startQHead * headSize, qResponsibleCols * sizeof(float));
- memcpy(concatScale + qResponsibleCols, keyScale + this->startKVHead * headSize,
+ int qkvStride = (ctx->attHeadNum + ctx->kvHeadNum + ctx->kvHeadNum) * ctx->attHeadSize;
+ int groups = ctx->groupsize == -1 ? 1 : hiddenSize / ctx->groupsize;
+ concatScale = (float *)malloc(groups * responsibleCols * sizeof(float));
+ concatZero = (float *)malloc(groups * responsibleCols * sizeof(float));
+ for (int i = 0; i < groups; ++i) {
+ memcpy(concatScale + i * responsibleCols, queryScale + i * qkvStride + this->startQHead * headSize, qResponsibleCols * sizeof(float));
+ memcpy(concatScale + i * responsibleCols + qResponsibleCols, keyScale + i * qkvStride + this->startKVHead * headSize,
kvResponsibleCols * sizeof(float));
- memcpy(concatScale + qResponsibleCols + kvResponsibleCols, valueScale + this->startKVHead * headSize,
+ memcpy(concatScale + i * responsibleCols + qResponsibleCols + kvResponsibleCols, valueScale + i * qkvStride + this->startKVHead * headSize,
kvResponsibleCols * sizeof(float));
- memcpy(concatZero, queryZero + this->startQHead * headSize, qResponsibleCols * sizeof(float));
- memcpy(concatZero + qResponsibleCols, keyZero + this->startKVHead * headSize,
+ memcpy(concatZero + i * responsibleCols, queryZero + i * qkvStride + this->startQHead * headSize, qResponsibleCols * sizeof(float));
+ memcpy(concatZero + i * responsibleCols + qResponsibleCols, keyZero + i * qkvStride + this->startKVHead * headSize,
kvResponsibleCols * sizeof(float));
- memcpy(concatZero + qResponsibleCols + kvResponsibleCols, valueZero + this->startKVHead * headSize,
+ memcpy(concatZero + i * responsibleCols + qResponsibleCols + kvResponsibleCols, valueZero + i * qkvStride + this->startKVHead * headSize,
kvResponsibleCols * sizeof(float));
+ }
}

xft::Matrix<WeiT> convertedqkvWeight;
ctx->mmHelper->convertWeight(trans, hiddenSize, responsibleCols, concatBuf, concatScale, concatZero,
- convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum);
+ convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum, ctx->groupsize);

#ifdef XFT_GPU
xft::Matrix<WeiT> qkvWeightT;
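For context on the loop above: the full QKV scale/zero tensors are laid out with one row per quantization group, each row `qkvStride` floats wide, and the copy walks the groups and concatenates this rank's Q, K and V slices into a contiguous `[groups x responsibleCols]` buffer (once for the scales, once for the zero points). A standalone sketch of that copy pattern, with hypothetical helper and parameter names:

```cpp
// Sketch of the concatenation pattern above; names are illustrative, not the PR's API.
#include <cstring>

static void concatGroupwise(float *dst, const float *q, const float *k, const float *v,
                            int groups, int qkvStride, int responsibleCols,
                            int qOff, int qCols, int kvOff, int kvCols) {
    for (int g = 0; g < groups; ++g) {
        float *row = dst + g * responsibleCols;  // one destination row per group
        std::memcpy(row, q + g * qkvStride + qOff, qCols * sizeof(float));
        std::memcpy(row + qCols, k + g * qkvStride + kvOff, kvCols * sizeof(float));
        std::memcpy(row + qCols + kvCols, v + g * qkvStride + kvOff, kvCols * sizeof(float));
    }
}
```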
@@ -182,7 +186,7 @@ class Attention {
xft::Matrix<WeiT> convertedOutWeight;
ctx->mmHelper->convertWeight(trans, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, attnOutWeight, attnOutScale,
attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedOutWeight,
- attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, true);
+ attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, ctx->groupsize, true);

#ifdef XFT_GPU
xft::Matrix<WeiT> outWeightT;
@@ -289,11 +293,11 @@ class Attention {
if (qkvBias.Size() == 0) {
ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(),
imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
- qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride());
+ qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), ctx->groupsize);
} else {
ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f,
imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
- qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data());
+ qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data(), ctx->groupsize);
}
t2.release();
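The extra `ctx->groupsize` argument tells the matmul helper to re-index scales and zero points per group along the reduction (input) dimension. As a rough, hedged reference of what that implies, here is a naive group-aware int8 GEMM; this is not xFasterTransformer's kernel, and the affine convention `w ≈ (q - zero) * scale` is an assumption:

```cpp
// Naive group-aware int8 GEMM, for reference only (not the xFT kernel).
// scale/zero are [groups x N]; the group index follows the K (input) dimension.
#include <cstdint>

void gemm_int8_groupwise(int M, int N, int K,
                         const float *A, int lda,        // activations, M x K
                         const int8_t *B,                // quantized weights, K x N
                         const float *scale, const float *zero,
                         float *C, int ldc,              // output, M x N
                         int groupsize) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
                int g = (groupsize == -1) ? 0 : k / groupsize;
                // Assumed affine convention: w = (q - zero) * scale.
                float w = (static_cast<float>(B[k * N + n]) - zero[g * N + n]) * scale[g * N + n];
                acc += A[m * lda + k] * w;
            }
            C[m * ldc + n] = acc;
        }
    }
}
```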

@@ -405,26 +409,26 @@ class Attention {
ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(),
1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(),
attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f,
- outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride());
+ outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize);
} else {
float *pbias = attnOutputBias.Data();
if (attnOutputBias.Size() == 0) { pbias = nullptr; }
ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
- outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride());
+ outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize);
}
} else {
if (attnOutputBias.Size() == 0) {
ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
- outBuffer.Stride());
+ outBuffer.Stride(), ctx->groupsize);
} else {
ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
- outBuffer.Stride(), attnOutputBias.Data());
+ outBuffer.Stride(), attnOutputBias.Data(), ctx->groupsize);
}
}
t5.release();
@@ -495,11 +499,11 @@ class Attention {
if (qkvBias.Size() == 0) {
ctx->mmHelper->compute(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f, imBuffer.Data(),
imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
- qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride());
+ qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), ctx->groupsize);
} else {
ctx->mmHelper->compute_bias(false, imBuffer.Rows(), qkvWeight.Cols(), imBuffer.Cols(), 1.0f,
imBuffer.Data(), imBuffer.Stride(), qkvWeight.Data(), qkvWeightScale.Data(), qkvWeightZero.Data(),
- qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data());
+ qkvWeightSum.Data(), 0.0f, qkvGroupMatMul.Data(), qkvGroupMatMul.Stride(), qkvBias.Data(), ctx->groupsize);
}
t2.release();

@@ -588,26 +592,26 @@ class Attention {
ctx->mmHelper->compute_residential(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(),
1.0f, attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(),
attnOutputWeightScale.Data(), attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f,
- outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride());
+ outBuffer.Data(), outBuffer.Stride(), pbias, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize);
} else {
float *pbias = attnOutputBias.Data();
if (attnOutputBias.Size() == 0) { pbias = nullptr; }
ctx->mmHelper->compute_resext(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
- outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride());
+ outBuffer.Stride(), pbias, gamma, inputBuffer.Data(), inputBuffer.Stride(), ctx->groupsize);
}
} else {
if (attnOutputBias.Size() == 0) {
ctx->mmHelper->compute(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
- outBuffer.Stride());
+ outBuffer.Stride(), ctx->groupsize);
} else {
ctx->mmHelper->compute_bias(false, attnSplit.Rows(), attnOutputWeight.Cols(), attnSplit.Cols(), 1.0f,
attnSplit.Data(), attnSplit.Stride(), attnOutputWeight.Data(), attnOutputWeightScale.Data(),
attnOutputWeightZero.Data(), attnOutputWeightSum.Data(), 0.0f, outBuffer.Data(),
- outBuffer.Stride(), attnOutputBias.Data());
+ outBuffer.Stride(), attnOutputBias.Data(), ctx->groupsize);
}
}
t5.release();
@@ -1183,15 +1187,15 @@ class Attention {

// query, key, value weights
xft::Matrix<WeiT> qkvWeight;
- xft::Vector<float> qkvWeightScale; // if weight is int8
- xft::Vector<float> qkvWeightZero; // if weight is int8
+ xft::Matrix<float> qkvWeightScale; // if weight is int8
+ xft::Matrix<float> qkvWeightZero; // if weight is int8
xft::Vector<float> qkvWeightSum; // if weight is int8
// query, key, value bias
xft::Vector<float> qkvBias;

xft::Matrix<WeiT> attnOutputWeight;
- xft::Vector<float> attnOutputWeightScale; // if weight is int8
- xft::Vector<float> attnOutputWeightZero; // if weight is int8
+ xft::Matrix<float> attnOutputWeightScale; // if weight is int8
+ xft::Matrix<float> attnOutputWeightZero; // if weight is int8
xft::Vector<float> attnOutputWeightSum; // if weight is int8
xft::Vector<float> attnOutputBias;
