Add CUTLASS-based W4A4 #1515
Conversation
The CUDA code looks fine; of course, there are still lots of dots to connect on the Python side. The difference from #880 is that this is a regular GEMM rather than a mixed-data-types GEMM. In that regard, this operator is probably easier to make much more generic, to support other integer and maybe even some floating-point input data types. I'm at the moment making some minor changes on this PyTorch operator, and would strongly recommend modelling the CUDA code in a similar way: it simply looks nice, it makes extending the kernel to other data types much easier, it has extensive checks on operands, etc. Moreover, I think it would make sense at this point to discuss having a single CUTLASS-based kernel for GEMMs with both weights and activations scaled, kept in a single source file and handling both same and mixed data types, at least for SM 8.x archs - that would provide for minimum code duplication and easier maintenance in the future. As far as configurations (tile sizes, number of stages, etc.) are concerned, I'd suggest looking here instead, in the unit tests, and also comparing performance against the results reported by the CUTLASS profiler for the given combination of data types. I believe some sort of tuning of the configuration on the input shapes is a must in order to achieve decent performance; but I have to admit that in #880 the tuning is mostly ad hoc (for comparison, I find this approach more elaborate and meaningful). Thus, I think that coming up with some kind of systematic approach in that regard would be the most beneficial contribution towards eventual future use of CUTLASS-based kernels in torchao. (@drisspg: your comments welcome here.)
One thing on finding optimal params: @yifuwang was recently working on finding better configs for an AsyncMM. He did some manual elimination of configs that never seemed to be performant, and then fit a simple decision tree on a big sweep over M/K/N shapes that could be easily modeled in C++ (a sketch of the idea follows below). This is similar to what is done in the RowWise scaling. I think a little flow for this would be helpful; I can make an issue to track it. No major comments otherwise.
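(Purely as an illustration of that idea, here is a minimal C++ sketch of a shape-based heuristic of the kind that could be distilled from such a sweep; the thresholds and tile sizes below are hypothetical placeholders, not measured values:)

// Hypothetical sketch only: a hand-rolled "decision tree" over the M
// dimension, as could be distilled from an offline sweep over (M, K, N)
// shapes. Thresholds and tile sizes are placeholders, not tuned values.
struct GemmTileConfig {
  int threadblock_m, threadblock_n, threadblock_k;
  int num_stages;
};

inline GemmTileConfig select_gemm_tile_config(int m, int n, int k) {
  if (m <= 16) {
    return {16, 128, 128, 6};   // small M: narrow tiles, more stages
  } else if (m <= 32) {
    return {32, 128, 128, 5};
  }
  return {64, 128, 128, 4};     // larger M: wider tiles, fewer stages
}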
Thank you for the feedback.
Though this is nice on paper, I think Triton is the better alternative for other data types (INT8, FP8, ...). It's more flexible, and the autotuner also saves us some headache. We only have to use CUTLASS because of the lack of INT4 support in Triton, especially for INT4 Tensor cores - unless we can show that there are cases where Triton cannot reach the performance of CUTLASS (in the context of this PR, I'm only thinking about INT8 for SM8x, and additionally FP8 for SM89). Having said that, I'm OK with following a certain style/structure. Just point me to which one it should be, and I will make modifications accordingly.
torchao/ops.py
Outdated
    Returns:
        output: result tensor, in row-major layout.
    """
    assert A.dtype == B.dtype == torch.int8
We should add the alignment constraints as well, right?
How should I check for data alignment from Python? I guess in C++, I can check by testing divisibility of the memory address? (or perhaps there is a util function somewhere that I'm not aware of...)
Hmm, I think there is a restriction that k needs to be a multiple of 32, right?
Or at least 16 packed int4s.
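(For illustration only - a hedged C++ sketch of the two checks being discussed here, assuming 128-bit vectorized loads, i.e. 32 int4 values / 16 packed bytes per access; the helper name is hypothetical:)

#include <cstdint>
#include <torch/extension.h>

// Sketch of the constraints discussed above, under the assumption of 128-bit
// (16-byte) vectorized loads: k must cover a whole number of 128-bit chunks
// (32 int4 values), and the data pointer itself must be 16-byte aligned.
inline void check_int4_operand(const at::Tensor& t, int64_t k) {
  TORCH_CHECK(k % 32 == 0, "k must be a multiple of 32 for int4 operands");
  const auto addr = reinterpret_cast<std::uintptr_t>(t.data_ptr());
  TORCH_CHECK(addr % 16 == 0, "int4 operand must be 16-byte aligned");
}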
torchao/csrc/cuda/int4_cutlass.cu
Outdated
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
  // static int const kStages = 3;
  using ElementC = int32_t;
  using Gemm = cutlass::gemm::device::Gemm<
Do you know if the universal GEMM API can be used?
Will look into it. I wrote this quite some time ago...
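(For reference, a hedged and untested sketch of what the same configuration might look like through the universal GEMM device-level API; epilogue, swizzle, stage count and alignments are left to CUTLASS defaults here, and the Arguments construction differs from device::Gemm, e.g. it additionally takes a GemmUniversalMode and a batch count:)

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm_universal.h>

// Untested sketch: same element types, layouts and tile shapes as the
// device::Gemm instantiation above, expressed via device::GemmUniversal.
using GemmUniversalS4S4 = cutlass::gemm::device::GemmUniversal<
    cutlass::int4b_t, cutlass::layout::RowMajor,     // A
    cutlass::int4b_t, cutlass::layout::ColumnMajor,  // B
    int32_t, cutlass::layout::RowMajor,              // C/D
    int32_t,                                         // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 256, 128>,         // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 128>,           // warp tile
    cutlass::gemm::GemmShape<16, 8, 64>>;            // MMA instruction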
Attached is a minor patch. The structure of CUTLASS-based kernels is typically always the same (see also rowwise scaled MM in PyTorch, mentioned in my previous comment, as well as my CUTLASS-based mixed data types and 2:4 sparsity kernels in PyTorch): from the bottom up, there is always an operator implementation function that checks the inputs and then starts a dispatching chain (where run-time data types etc. are translated to compile-time template arguments), which ends in a typical CUTLASS-based GEMM kernel (that part is boilerplate). Also, as mentioned in my previous comment, while rowwise scaled MM is very similar in structure, I like how it looks the most - because of the clever use of variadic template arguments to reduce the clutter, and because of the clean extraction of input checks and configuration selection into separate functions, etc. So I'd suggest we have your C++ code integrated in the way sketched by the attached diff, and then also make minor changes to the C++ code so that it looks closer to the rowwise scaled MM implementation. (Of course, the operator name and some other stuff on the Python side will have to be changed too.)
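(A rough, purely illustrative C++ skeleton of that bottom-up structure - not the actual torchao code, all names are made up:)

#include <torch/extension.h>

// Illustrative skeleton only: operator entry point -> operand checks ->
// dispatch chain (run-time dtypes mapped to compile-time template arguments)
// -> boilerplate CUTLASS GEMM launch at the innermost level.
namespace sketch {

template <typename ElementA, typename ElementB>
void launch_cutlass_gemm(const at::Tensor& a, const at::Tensor& b,
                         at::Tensor& out) {
  // ... typical CUTLASS GEMM instantiation and launch (boilerplate) ...
}

void dispatch_on_dtypes(const at::Tensor& a, const at::Tensor& b,
                        at::Tensor& out) {
  if (a.scalar_type() == at::kChar && b.scalar_type() == at::kChar) {
    launch_cutlass_gemm<int8_t, int8_t>(a, b, out);  // dtype -> templates
  } else {
    TORCH_CHECK(false, "unsupported dtype combination");
  }
}

at::Tensor linear_op(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "expected 2D operands");  // checks
  at::Tensor out =
      at::empty({a.size(0), b.size(0)}, a.options().dtype(at::kInt));
  dispatch_on_dtypes(a, b, out);
  return out;
}

}  // namespace sketch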
As far as performance between the various implementations is concerned: I'd say in general there are three ways to implement kernels - Triton-based, CUTLASS-based, and custom, i.e. from scratch (like Marlin-based kernels). In my experience so far (all on the Ampere arch), CUTLASS-based kernels are oftentimes somewhat faster than Triton-based kernels, while for some corner-case input tensor sizes, custom kernels (well, Marlin-based at least) can be significantly faster than CUTLASS-based ones. Furthermore, with Triton there is the least amount of flexibility with respect to upstream changes (they just don't support some input data types, they don't support 2:4 sparsity, etc.); with CUTLASS it's somewhat easier to have changes we may need accepted; and for custom kernels this is obviously not an issue at all. However, Triton kills it when it comes to compilation, in particular regarding fusing the GEMM with other kernels; CUTLASS has some support for compilation, but doing fusion is rather cumbersome at the moment; and there is obviously no compilation support at all for custom kernels. Then, writing custom kernels would probably lead to lots of code duplication, and with CUTLASS this may also be an issue, even if to a smaller extent. Etc. - so it's all a matter of trade-offs. Still, having auto-tuning and auto-quantization in mind, I believe it may still be good to have as many different kernels in torchao as possible, so I'd expect more CUTLASS-based kernels to be written besides these W4A8 and W4A4 kernels - and this is the exact reason that, as discussed above, I'd prefer to have as much code as possible shared between these kernels.
Might be interesting to try out QAT with this setting. cc @andrewor14
I've made these changes to the existing CUTLASS-based W4A8 kernel in #1545, so it should now be easier to eventually include the W4A4 functionality there.
  const auto is_sm8x = dprops->major >= 8;

  if (is_sm8x) {
    using ElementA = cutlass::int4b_t;
What is the reason to skip the dispatch_on_tensor_a_and_tensor_b step here? I've put that one there exactly to be able to easily choose the ElementA and ElementB template arguments. If this file were changed to have dispatch_on_tensor_a_and_tensor_b, then s4s4_linear_cutlass.cu and s8s4_linear_cutlass.cu would be pretty much the same, except for the different ElementA/ElementB datatype pair, and select_config would differ. For this reason, my idea with the latest changes in s8s4_linear_cutlass.cu was actually to keep everything in a single file. There would be one operator visible to the Python side of torchao; instead of s8s4_linear_cutlass, we could name it rowwise_scaled_linear_cutlass, and it could handle S8/S4, S4/S4, and in the future pretty much any combination of same or mixed data types, both integer and floating point (and it could also be extended to handle SM90 and higher). The advantage is that there would be no duplicated code and it would be easy to extend with new data types; the disadvantage is that a large number of templates would be instantiated in a single file, so torchao build times would suffer badly (I have no immediate idea how, but I hope this could somehow be fixed, through some kind of explicit template instantiation). Please allow me some time, and I'll try to come up with a patch to show what I mean...
@alexsamardzic Thank you for your prompt feedback. I was going to work on this a little bit more before asking for another round of review. I wasn't sure how much code-sharing you intended to have. I was trying to work with a single file, but the biggest blocker is this one:

    if (tensor_a.scalar_type() == at::ScalarType::Char) {
      if (tensor_b.scalar_type() == at::ScalarType::Char) {
        if (tensor_a.size(1) == 2 * tensor_b.size(1)) {

That's why I decided to move the inner-most template to a separate header file for sharing. I have another question as well.
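(A small illustrative sketch of the ambiguity behind this blocker - a hypothetical function, not from the PR: both S8 and packed-S4 tensors arrive as at::kChar, so the runtime branch can only fall back on a shape heuristic:)

#include <torch/extension.h>

// Sketch only: PyTorch has no sub-byte dtypes, so an S8 tensor and a packed
// S4 tensor both report at::ScalarType::Char. The only runtime hint left is
// the shape: a packed S4 weight has half as many stored columns as the
// activation's K dimension.
void dispatch_example(const at::Tensor& tensor_a, const at::Tensor& tensor_b) {
  if (tensor_a.scalar_type() == at::ScalarType::Char &&
      tensor_b.scalar_type() == at::ScalarType::Char) {
    if (tensor_a.size(1) == 2 * tensor_b.size(1)) {
      // A is S8, B is packed S4 (the W4A8 case).
    } else if (tensor_a.size(1) == tensor_b.size(1)) {
      // Could be S8/S8 or packed-S4/packed-S4 - indistinguishable here, which
      // is why the dtype pair ends up encoded in the operator name instead.
    }
  }
}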
Side comments: I don't mind code duplication. Realistically speaking, there are not that many useful combinations, so code duplication is not too bad. And as you probably already know, for the CUTLASS device-level API not all combinations work (and when they don't work, there are mysterious errors 😅).
I realized, while making changes, that your point on dispatching on the input/weight tensor data types is right. Namely, PyTorch doesn't provide sub-byte data types, so the only way to differentiate between data type combinations like S4/S4 and S8/S8 is to have this information encoded somewhere else. Your approach was to encode it in the name of the operator, and this kind of approach could also solve the issue I was worried about - that not all templates get instantiated from a single source file. I made some further changes to maximize C++ code reuse. I don't mind code duplication either, but in this case it's (in my opinion) really cumbersome boilerplate code that, for sanity, I'd much prefer to keep in a single place. So I came up with the following patch, to be applied on top of the current state of this branch:

Patch

diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
similarity index 73%
rename from benchmarks/benchmark_s8s4_cutlass.py
rename to benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
index fbf07eb..00bcb0a 100644
--- a/benchmarks/benchmark_s8s4_cutlass.py
+++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -2,7 +2,7 @@ import pandas as pd
import torch
from tqdm import tqdm
-from torchao.ops import s8s4_linear_cutlass
+from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
from torchao.utils import benchmark_torch_function_in_microseconds
@@ -24,8 +24,8 @@ def benchmark(m: int, k: int, n: int):
A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)
fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
- s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
- s8s4_linear_cutlass, A, A_scale, B, B_scale, C
+ rowwise_scaled_linear_cutlass_s8s4_time = benchmark_torch_function_in_microseconds(
+ rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
)
return {
@@ -33,8 +33,8 @@ def benchmark(m: int, k: int, n: int):
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
- "s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time,
- "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time,
+ "rowwise_scaled_linear_cutlass latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time,
+ "speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
}
@@ -48,5 +48,5 @@ if __name__ == "__main__":
results.append(benchmark(m, k, n))
df = pd.DataFrame(results)
- df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
+ df.to_csv("rowwise_scaled_linear_cutlass_s8s4_time_results.csv", index=False)
print(df.to_markdown(index=False))
diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py
new file mode 100644
index 0000000..422b100
--- /dev/null
+++ b/test/test_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,128 @@
+import itertools
+
+import pytest
+import torch
+
+from torchao.ops import (
+ rowwise_scaled_linear_cutlass_s4s4,
+ rowwise_scaled_linear_cutlass_s8s4,
+)
+from torchao.quantization.utils import group_quantize_tensor_symmetric
+from torchao.utils import compute_max_diff
+
+ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
+ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
+ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
+ (2, 512, 128),
+ (3, 2048, 2048),
+ (4, 3584, 640),
+ (13, 8704, 8576),
+ (26, 18944, 1664),
+ (67, 6656, 1408),
+]
+ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
+ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
+ itertools.product(
+ ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
+ ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
+ ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
+ ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
+ )
+)
+
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
+ "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
+)
+def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
+ size_m, size_n, size_k = size_mnk
+
+ input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+ weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+ bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+ input_2d = input.view(-1, input.shape[-1])
+ input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
+ input_2d, 4, size_k, dtype
+ )
+ assert torch.all(input_2d_zeros == 0)
+ input_s8 = input_2d_s8.reshape(input.shape)
+ input_s4 = ((input_s8[:, :, 1::2] & 0xF) << 4) | (input_s8[:, :, 0::2] & 0xF)
+ input_scales = input_2d_scales.reshape(input.shape[:-1])
+
+ weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
+ weight, 4, size_n, dtype
+ )
+ assert torch.all(weight_zeros == 0)
+ weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
+
+ # If torch.nn.functional.linear(input, weight, bias) used as
+ # reference, the error would be too big. The calculation below is
+ # approximately what rowwise_scaled_linear_cutlass kernel is doing
+ # (except that matrix multiplication is over integers there)).
+ size_m_2d = input_2d.shape[0]
+ output_ref = (
+ (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
+ * input_2d_scales.view(size_m_2d, 1)
+ * weight_scales.view(1, size_n)
+ )
+ if bias is not None:
+ output_ref += bias
+ output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
+
+ fn_inputs = (input_s4, input_scales, weight_s4, weight_scales, bias)
+ try:
+ output = rowwise_scaled_linear_cutlass_s4s4(*fn_inputs)
+ except NotImplementedError:
+ pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented")
+
+ max_diff = compute_max_diff(output, output_ref)
+ assert max_diff < 1e-3
+
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
+ "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
+)
+def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
+ size_m, size_n, size_k = size_mnk
+
+ input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+ weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+ bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+ input_2d = input.view(-1, input.shape[-1])
+ input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
+ input_2d, 8, size_k, dtype
+ )
+ assert torch.all(input_2d_zeros == 0)
+ input_s8 = input_2d_s8.reshape(input.shape)
+ input_scales = input_2d_scales.reshape(input.shape[:-1])
+
+ weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
+ weight, 4, size_n, dtype
+ )
+ assert torch.all(weight_zeros == 0)
+ weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
+
+ # If torch.nn.functional.linear(input, weight, bias) used as
+ # reference, the error would be too big. The calculation below is
+ # approximately what rowwise_scaled_linear_cutlass kernel is doing
+ # (except that matrix multiplication is over integers there)).
+ size_m_2d = input_2d.shape[0]
+ output_ref = (
+ (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
+ * input_2d_scales.view(size_m_2d, 1)
+ * weight_scales.view(1, size_n)
+ )
+ if bias is not None:
+ output_ref += bias
+ output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
+
+ fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
+ try:
+ output = rowwise_scaled_linear_cutlass_s8s4(*fn_inputs)
+ except NotImplementedError:
+ pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented")
+
+ max_diff = compute_max_diff(output, output_ref)
+ assert max_diff < 5e-3
diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py
deleted file mode 100644
index 6510ada..0000000
--- a/test/test_s8s4_linear_cutlass.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import itertools
-
-import pytest
-import torch
-
-from torchao.ops import s8s4_linear_cutlass
-from torchao.quantization.utils import group_quantize_tensor_symmetric
-from torchao.utils import compute_max_diff
-
-S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
-S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
-S8S4_LINEAR_CUTLASS_SIZE_MNK = [
- (2, 512, 128),
- (3, 2048, 2048),
- (4, 3584, 640),
- (13, 8704, 8576),
- (26, 18944, 1664),
- (67, 6656, 1408),
-]
-S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
-S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
- itertools.product(
- S8S4_LINEAR_CUTLASS_DTYPE,
- S8S4_LINEAR_CUTLASS_BATCH_SIZE,
- S8S4_LINEAR_CUTLASS_SIZE_MNK,
- S8S4_LINEAR_CUTLASS_USE_BIAS,
- )
-)
-
-
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
- "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
-)
-def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
- size_m, size_n, size_k = size_mnk
-
- input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
- weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
- bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
-
- input_2d = input.view(-1, input.shape[-1])
- input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
- input_2d, 8, size_k, dtype
- )
- assert torch.all(input_2d_zeros == 0)
- input_s8 = input_2d_s8.reshape(input.shape)
- input_scales = input_2d_scales.reshape(input.shape[:-1])
-
- weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
- weight, 4, size_n, dtype
- )
- assert torch.all(weight_zeros == 0)
- weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
-
- # If torch.nn.functional.linear(input, weight, bias) used as
- # reference, the error would be too big. The calculation below is
- # approximately what s8s4_linear_cutlass kernel is doing (except
- # that matrrix multiplication is over integers there)).
- size_m_2d = input_2d.shape[0]
- output_ref = (
- (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
- * input_2d_scales.view(size_m_2d, 1)
- * weight_scales.view(1, size_n)
- )
- if bias is not None:
- output_ref += bias
- output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
-
- fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
- try:
- output = s8s4_linear_cutlass(*fn_inputs)
- except NotImplementedError:
- pytest.xfail("s8s4_linear_cutlass() op not implemented")
-
- max_diff = compute_max_diff(output, output_ref)
- assert max_diff < 5e-3
diff --git a/torchao/csrc/cuda/int4_cutlass.cu b/torchao/csrc/cuda/int4_cutlass.cu
deleted file mode 100644
index 452abcc..0000000
--- a/torchao/csrc/cuda/int4_cutlass.cu
+++ /dev/null
@@ -1,231 +0,0 @@
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-
-// copied from s8s4_linear_cutlass.cu
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
- defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_INT4_MM_CUTLASS
-#endif
-
-#if defined(BUILD_INT4_MM_CUTLASS)
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/device/gemm_universal.h"
-#include "cutlass/gemm/device/gemm.h"
-#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
-#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-
-#define CUTLASS_STATUS_CHECK(status) \
- { \
- TORCH_CHECK(status == cutlass::Status::kSuccess, \
- __func__, " : Got CUTLASS error: ", \
- cutlassGetStatusString(status)); \
- }
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_INT4_MM_CUTLASS)
-// define common params
-using ElementA = cutlass::int4b_t;
-using ElementB = cutlass::int4b_t;
-using ElementAccumulator = int32_t;
-using OpClass = cutlass::arch::OpClassTensorOp;
-using ArchTag = cutlass::arch::Sm80;
-
-// how many elements to load at a time -> load 128-bit = 32 x 4-bit
-constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
-constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
-#endif
-
-// we will do input checks in python. A and B are stored as int8
-torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) {
-#if defined(BUILD_INT4_MM_CUTLASS)
- int M = A.size(0);
- int K = A.size(1) * 2;
- int N = B.size(1);
- torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32));
-
- // some configs for int4 mma
- // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
- // using default config. this can be tuned.
- using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
- using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
- // static int const kStages = 3;
- using ElementC = int32_t;
- using Gemm = cutlass::gemm::device::Gemm<
- ElementA, cutlass::layout::RowMajor, // A matrix
- ElementB, cutlass::layout::ColumnMajor, // B matrix
- ElementC, cutlass::layout::RowMajor, // C matrix
- ElementAccumulator, OpClass, ArchTag,
- ThreadblockShape, WarpShape, InstructionShape
- >;
- Gemm::Arguments args {
- {M, N, K},
- {reinterpret_cast<ElementA *>(A.data_ptr<int8_t>()), K},
- {reinterpret_cast<ElementB *>(B.data_ptr<int8_t>()), K},
- {C.data_ptr<ElementC>(), N},
- {C.data_ptr<ElementC>(), N},
- {1, 0} // epilogue
- };
- Gemm gemm_op;
- CUTLASS_STATUS_CHECK(gemm_op(args));
- return C;
-#else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
-#endif
-}
-
-template<
- typename ElementC,
- typename ThreadblockShape,
- typename WarpShape,
- typename InstructionShape,
- int numStages>
-void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) {
- // problem shape
- int M = A.size(0);
- int K = A.size(1) * 2;
- int N = B.size(1);
-
- constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // 8 for BF16/FP16
- using ElementEpilogue = float;
- constexpr int numEpilogueStages = 1;
-
- // build epilogue visitor tree
- using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages
- >;
-
- using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
- constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest;
- using Multiply = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode
- >;
-
- // (1, N)
- using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
- OutputTileThreadMap, ElementC,
- cute::Stride<cute::_0, cute::_1, int32_t> // MNL
- >;
- using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, Accum, ColScale>;
-
- // (M, 1)
- using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
- OutputTileThreadMap, ElementC,
- cute::Stride<cute::_1, cute::_0, int32_t> // MNL
- >;
- using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, EVTCompute0, RowScale>;
-
- using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
- OutputTileThreadMap, ElementC, RoundMode,
- cute::Stride<int64_t, cute::_1, int64_t> // MNL
- >;
- using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<Output, EVTCompute1>;
-
- using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
- ElementA, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, AlignmentA,
- ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB,
- ElementC, cutlass::layout::RowMajor, AlignmentC,
- ElementAccumulator, ElementEpilogue, OpClass, ArchTag,
- ThreadblockShape, WarpShape, InstructionShape,
- EVTOutput,
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
- numStages,
- cutlass::arch::OpMultiplyAddSaturate, // OpMultiplyAdd does not work
- numEpilogueStages
- >::GemmKernel;
- using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
-
- // col_scale, row_scale, and C must have the same dtype
- const ElementA *A_ptr = reinterpret_cast<ElementA *>(A.data_ptr<int8_t>());
- const ElementB *B_ptr = reinterpret_cast<ElementB *>(B.data_ptr<int8_t>());
- const ElementC *col_scale_ptr = reinterpret_cast<ElementC *>(col_scale.data_ptr());
- const ElementC *row_scale_ptr = reinterpret_cast<ElementC *>(row_scale.data_ptr());
- ElementC *C_ptr = reinterpret_cast<ElementC *>(C.data_ptr());
-
- typename EVTOutput::Arguments callback_args{
- {
- {
- {}, // Accum
- {col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}}, // ColScale
- {} // Multiply
- }, // EVTCompute0
- {row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}}, // RowScale
- {} // Multiply
- }, // EVTCompute1
- {C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}} // EVTOutput
- };
-
- typename DeviceGemm::Arguments args(
- cutlass::gemm::GemmUniversalMode::kGemm,
- cutlass::gemm::GemmCoord{M, N, K},
- 1, // batch_split
- callback_args,
- A_ptr, B_ptr, nullptr, nullptr, // unsued C_ptr and D_ptr
- M * K, N * K, 0, 0, // batch_stride A, B, C, D
- K, K, 0, 0 // stride A, B, C, D
- );
-
- DeviceGemm gemm_op;
- auto stream = at::cuda::getCurrentCUDAStream();
- CUTLASS_STATUS_CHECK(gemm_op.can_implement(args));
- CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream));
-}
-
-// we will do input checks in python. A and B are stored as int8
-// this function is based on the following cutlass example
-// https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu
-// also with the help of emitted code from cutlass Python
-torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) {
-#if defined(BUILD_INT4_MM_CUTLASS)
- int M = A.size(0);
- int N = B.size(1);
- torch::Tensor C = torch::empty({M, N}, row_scale.options());
-
- // some configs for int4 mma
- // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
- // using default config. this can be tuned.
- using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
- using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
- constexpr int numStages = 3;
-
- AT_DISPATCH_SWITCH(
- row_scale.scalar_type(),
- "scaled_int4_mm_cutlass",
- AT_DISPATCH_CASE(
- torch::ScalarType::Half,
- [&]() {
- using ElementC = cutlass::half_t;
- scaled_int4_mm_cutlass_dispatch<
- ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
- A, B, row_scale, col_scale, C);
- }
- )
- AT_DISPATCH_CASE(
- torch::ScalarType::BFloat16,
- [&]() {
- using ElementC = cutlass::bfloat16_t;
- scaled_int4_mm_cutlass_dispatch<
- ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
- A, B, row_scale, col_scale, C);
- }
- )
- );
-
- return C;
-#else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
- m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass);
- m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass);
-}
-
-} // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh
new file mode 100644
index 0000000..d1969bc
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh
@@ -0,0 +1,578 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <c10/util/Exception.h>
+
+#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
+ defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
+#define BUILD_ROWWISE_SCALED_LINEAR_CUTLASS
+#endif
+
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+#include <cuda_runtime.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm_universal.h>
+#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
+#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
+#include <cutlass/gemm/device/gemm_universal_adapter.h>
+
+#define CUTLASS_STATUS_CHECK(status) \
+ { \
+ TORCH_CHECK(status == cutlass::Status::kSuccess, \
+ __func__, " : Got CUTLASS error: ", \
+ cutlassGetStatusString(status)); \
+ }
+#endif
+
+namespace torchao {
+
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+template<
+ typename ThreadblockShape,
+ typename WarpShape,
+ typename InstructionShape,
+ typename ThreadblockSwizzle,
+ int NumStages,
+ typename ElementA,
+ typename ElementB,
+ typename ElementOutput,
+ typename ElementC,
+ typename UseTensorC,
+ typename ElementAScale,
+ typename ElementBScale>
+void rowwise_scaled_linear_kernel_cutlass_sm8x(
+ const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+ const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+ const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+ static_assert((cutlass::sizeof_bits<ElementA>::value >= 8 ||
+ 8 % cutlass::sizeof_bits<ElementA>::value == 0) &&
+ (cutlass::sizeof_bits<ElementB>::value >= 8 ||
+ 8 % cutlass::sizeof_bits<ElementB>::value == 0));
+
+ using SmArch = cutlass::arch::Sm80;
+
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::ColumnMajor;
+ using LayoutOutput = cutlass::layout::RowMajor;
+
+ using ElementAccumulator = int32_t;
+ using Operator =
+ std::conditional_t<std::is_same<ElementA, ElementB>::value,
+ cutlass::arch::OpMultiplyAddSaturate,
+ cutlass::arch::OpMultiplyAddMixedInputUpcast>;
+
+ using ElementEpilogue = float;
+
+ constexpr auto NumEVTEpilogueStages = 1;
+
+ const int m = tensor_a.size(0);
+ const int n = tensor_b.size(0);
+ int k = tensor_a.size(1);
+ if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
+ k *= 8 % cutlass::sizeof_bits<ElementA>::value;
+ }
+
+ constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+ constexpr int AlignmentAScale =
+ 128 / cutlass::sizeof_bits<ElementAScale>::value;
+ constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+ constexpr int AlignmentBScale =
+ 128 / cutlass::sizeof_bits<ElementBScale>::value;
+ constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+ constexpr int AlignmentOutput =
+ 128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+ // Check for current CUTLASS limitations w.r.t. alignments.
+ TORCH_CHECK(k % AlignmentA == 0,
+ __func__, " : Number of columns of tensor A must be divisible ",
+ "by ", AlignmentA);
+ TORCH_CHECK(k % AlignmentB == 0,
+ __func__, " : Number of columns of tensor B must be divisible ",
+ "by ", AlignmentB);
+ TORCH_CHECK(n % AlignmentC == 0,
+ __func__, " : Number of columns of tensor C must be divisible ",
+ "by ", AlignmentC);
+
+ using TensorAScaleTileThreadMap =
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
+ ThreadblockShape,
+ WarpShape,
+ ElementAScale,
+ AlignmentAScale,
+ NumEVTEpilogueStages>;
+ using TensorBScaleTileThreadMap =
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
+ ThreadblockShape,
+ WarpShape,
+ ElementBScale,
+ AlignmentBScale,
+ NumEVTEpilogueStages>;
+ using TensorCTileThreadMap =
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
+ ThreadblockShape,
+ WarpShape,
+ ElementC,
+ AlignmentC,
+ NumEVTEpilogueStages>;
+ using OutputTileThreadMap =
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
+ ThreadblockShape,
+ WarpShape,
+ ElementOutput,
+ AlignmentOutput,
+ NumEVTEpilogueStages>;
+
+ using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+ using TensorAScale =
+ cutlass::epilogue::threadblock::VisitorColBroadcast<
+ TensorAScaleTileThreadMap,
+ ElementAScale,
+ cute::Stride<cute::_1, cute::_0, int64_t>>;
+ using TensorAScaleArguments = typename TensorAScale::Arguments;
+
+ using TensorBScale =
+ cutlass::epilogue::threadblock::VisitorRowBroadcast<
+ TensorBScaleTileThreadMap,
+ ElementBScale,
+ cute::Stride<cute::_0, cute::_1, int64_t>>;
+ using TensorBScaleArguments = typename TensorBScale::Arguments;
+
+ using TensorCScalar =
+ cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
+ using TensorCTensor =
+ cutlass::epilogue::threadblock::VisitorRowBroadcast<
+ TensorCTileThreadMap,
+ ElementC,
+ cute::Stride<cute::_0, cute::_1, int64_t>>;
+ using TensorC =
+ std::conditional_t<UseTensorC::value, TensorCTensor, TensorCScalar>;
+ using TensorCArguments = typename TensorC::Arguments;
+
+ using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest
+ >;
+ using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
+ ApplyAScale,
+ Accum,
+ TensorAScale>;
+
+ using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest
+ >;
+ using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
+ ApplyBScale,
+ EVTApplyAScale,
+ TensorBScale>;
+
+ using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::plus, ElementEpilogue, ElementEpilogue,
+ cutlass::FloatRoundStyle::round_to_nearest
+ >;
+ using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
+ ApplySum,
+ EVTApplyBScale,
+ TensorC>;
+
+ using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
+ OutputTileThreadMap, ElementOutput,
+ cutlass::FloatRoundStyle::round_to_nearest,
+ cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
+ >;
+
+ using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
+ Output,
+ EVTApplySum>;
+
+ using EVTKernel =
+ typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+ ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
+ ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
+ ElementOutput, LayoutOutput, AlignmentOutput,
+ ElementAccumulator,
+ ElementEpilogue,
+ cutlass::arch::OpClassTensorOp,
+ SmArch,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EVTOutput,
+ ThreadblockSwizzle,
+ NumStages,
+ Operator,
+ NumEVTEpilogueStages
+ >::GemmKernel;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
+
+ cutlass::gemm::GemmCoord problem_size(m, n, k);
+ constexpr auto SplitKFactor = 1;
+
+ TensorAScaleArguments tensor_a_scale_arguments{
+ (ElementAScale*)tensor_a_scale.data_ptr(),
+ ElementAScale(1),
+ {cute::_1{}, cute::_0{}, problem_size.m()}
+ };
+ TensorBScaleArguments tensor_b_scale_arguments{
+ (ElementBScale*)tensor_b_scale.data_ptr(),
+ ElementBScale(1),
+ {cute::_0{}, cute::_1{}, problem_size.n()}
+ };
+ TensorCArguments tensor_c_arguments{
+ [&]() -> TensorCArguments {
+ if constexpr (UseTensorC::value) {
+ return {(ElementC*)tensor_c.data_ptr(),
+ ElementC(0),
+ {cute::_0{}, cute::_1{}, problem_size.n()}};
+ } else {
+ return {ElementC(0)};
+ }
+ }()
+ };
+ typename Output::Arguments output_arguments{
+ (ElementOutput*)tensor_d.data_ptr(),
+ {problem_size.n(), cute::_1{}, problem_size.mn().product()}
+ };
+ typename EVTOutput::Arguments callback_arguments{
+ {
+ {
+ {
+ {}, // Accum
+ tensor_a_scale_arguments, // TensorAScale
+ {} // ApplyAScale
+ }, // EVTApplyAScale
+ tensor_b_scale_arguments, // TensorBScale
+ {}, // ApplyBScale
+ }, // EVTApplyBScale
+ tensor_c_arguments, // TensorC
+ {} // ApplySum
+ }, // EVTApplySum
+ output_arguments // Output
+ }; // EVTOutput
+
+ typename Gemm::Arguments arguments(
+ cutlass::gemm::GemmUniversalMode::kGemm,
+ problem_size,
+ SplitKFactor,
+ callback_arguments, // arguments of EVT callbacks
+ (ElementA*)tensor_a.data_ptr(),
+ (ElementB*)tensor_b.data_ptr(),
+ nullptr, // ptr C (unused)
+ nullptr, // ptr D (unused)
+ problem_size.mk().product(), // batch stride A
+ problem_size.nk().product(), // batch stride B
+ 0, // batch stride C (unused)
+ 0, // batch stride D (unused)
+ problem_size.k(), // stride A
+ problem_size.k(), // stride B
+ 0, // stride C (unused)
+ 0 // stride D (unused)
+ );
+
+ Gemm gemm_op;
+
+ cutlass::Status status;
+
+ // Verify that GEMM operation with given arguments can be performed
+ // by CUTLASS.
+ status = gemm_op.can_implement(arguments);
+ CUTLASS_STATUS_CHECK(status);
+
+ // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
+ const auto workspace_size = Gemm::get_workspace_size(arguments);
+ auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
+ at::TensorOptions().dtype(at::kByte));
+
+ // Initialize CUTLASS mixed datatypes GEMM object.
+ status = gemm_op.initialize(arguments, workspace.data_ptr(),
+ at::cuda::getCurrentCUDAStream());
+ CUTLASS_STATUS_CHECK(status);
+
+ // Perform mixed datatypes GEMM operation.
+ status = gemm_op.run(at::cuda::getCurrentCUDAStream());
+ CUTLASS_STATUS_CHECK(status);
+
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename ElementA, typename ElementB, typename... Types>
+static void select_config(
+ const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+ const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+ const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+ const auto dprops = at::cuda::getCurrentDeviceProperties();
+ const auto is_sm8x = dprops->major == 8;
+
+ if (is_sm8x) {
+ if constexpr (std::is_same<ElementA, cutlass::int4b_t>::value &&
+ std::is_same<ElementB, cutlass::int4b_t>::value) {
+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
+ using ThreadblockSwizzle =
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
+ constexpr auto NumStages = 3;
+ rowwise_scaled_linear_kernel_cutlass_sm8x<
+ ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+ NumStages, ElementA, ElementB, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ return;
+ } else if constexpr (std::is_same<ElementA, int8_t>::value &&
+ std::is_same<ElementB, cutlass::int4b_t>::value) {
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+ using ThreadblockSwizzle =
+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
+
+ // A minimal heuristic to improve performance for small number
+ // of inputs cases.
+ if (tensor_a.size(0) <= 16) {
+ using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
+ constexpr auto NumStages = 6;
+ rowwise_scaled_linear_kernel_cutlass_sm8x<
+ ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+ NumStages, ElementA, ElementB, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ } else if (tensor_a.size(0) <= 32) {
+ using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
+ constexpr auto NumStages = 5;
+ rowwise_scaled_linear_kernel_cutlass_sm8x<
+ ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+ NumStages, ElementA, ElementB, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ } else {
+ using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
+ constexpr auto NumStages = 4;
+ rowwise_scaled_linear_kernel_cutlass_sm8x<
+ ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+ NumStages, ElementA, ElementB, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ }
+ return;
+ }
+ }
+
+ TORCH_CHECK(false,
+ __func__, " : Operator not supported on SM", dprops->major, ".",
+ dprops->minor, " for given operands");
+}
+
+template<
+ typename ElementA,
+ typename ElementB,
+ typename ElementOutput,
+ typename... Types>
+static void
+dispatch_on_tensor_c(
+ const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+ const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+ const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+ if (tensor_c.numel() == 0) {
+ using ElementC = ElementOutput;
+ using UseTensorC = std::false_type;
+ select_config<
+ ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ return;
+ }
+
+ using UseTensorC = std::true_type;
+ if (tensor_c.scalar_type() == at::ScalarType::Half) {
+ using ElementC = cutlass::half_t;
+ select_config<
+ ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ return;
+ } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
+ using ElementC = cutlass::bfloat16_t;
+ select_config<
+ ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+ tensor_d);
+ return;
+ }
+
+ TORCH_CHECK(false,
+ __func__, " : Operator not supported for datatype ",
+ tensor_c.scalar_type(), " for addend");
+}
+
+template<typename ElementA, typename ElementB>
+static void
+dispatch_on_tensor_a_scale_and_tensor_b_scale(
+ const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+ const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+ const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+ TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
+ __func__, " : Operator not supported for output datatype ",
+ tensor_d.scalar_type(), " as it's different from the first ",
+ " operand scale datatype ", tensor_a_scale.scalar_type());
+
+ if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
+ tensor_b_scale.scalar_type() == at::ScalarType::Half) {
+ using ElementAScale = cutlass::half_t;
+ using ElementBScale = cutlass::half_t;
+ using ElementOutput = cutlass::half_t;
+ dispatch_on_tensor_c<ElementA, ElementB, ElementOutput, ElementAScale,
+ ElementBScale>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+ return;
+ } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
+ tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
+ using ElementAScale = cutlass::bfloat16_t;
+ using ElementBScale = cutlass::bfloat16_t;
+ using ElementOutput = cutlass::bfloat16_t;
+ dispatch_on_tensor_c<ElementA, ElementB, ElementOutput, ElementAScale,
+ ElementBScale>(
+ tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+ return;
+ }
+
+ TORCH_CHECK(false,
+ __func__, " : Operator not supported for combination of data ",
+ "types ", tensor_a_scale.scalar_type(),
+ " for first operand scale and ", tensor_b_scale.scalar_type(),
+ " for second operand scale");
+}
+
+template<typename ElementA, typename ElementB>
+void
+rowwise_scaled_linear_cutlass_check_inputs(
+ const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+ const at::Tensor& w_scale, const at::Tensor& bias) {
+ // Validate layouts of arguments.
+ TORCH_CHECK(xq.dim() >= 2,
+ __func__, " : Expected xq argument to be 2D or "
+ "higher-dimensional tensor, got ", xq.dim(), " dims");
+ TORCH_CHECK(xq.layout() == at::Layout::Strided,
+ __func__, " : Expected xq argument to be strided, got layout ",
+ xq.layout());
+ TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
+ __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
+ "D tensor, got ", x_scale.dim(), " dims");
+ TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
+ __func__, " : Expected xq scale argument to be strided, got "
+ "layout ", x_scale.layout());
+ TORCH_CHECK(wq.dim() == 2,
+ __func__, " : Expected wq argument to be 2D tensor, got ",
+ wq.dim(), " dims");
+ TORCH_CHECK(wq.layout() == at::Layout::Strided,
+ __func__, " : Expected wq argument to be strided, got layout ",
+ wq.layout());
+ TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
+ __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
+ "got ", w_scale.dim(), " dims");
+ TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
+ __func__, " : Expected wq scale argument to be strided, got "
+ "layout ", w_scale.layout());
+ if (bias.numel() > 0) {
+ TORCH_CHECK(bias.dim() == 1,
+ __func__, " : Expected bias argument to be 1D tensor, got ",
+ bias.dim(), " dims");
+ TORCH_CHECK(bias.layout() == at::Layout::Strided,
+ __func__, " : Expected bias argument to be strided, got ",
+ "layout ", bias.layout());
+ }
+
+ // Validate sizes of arguments.
+ const auto xq_sizes = xq.sizes().vec();
+ TORCH_CHECK(xq_sizes.back() == wq.size(1) ||
+ xq_sizes.back() == 2 * wq.size(1),
+ __func__, " : Expected xq argument to have ", wq.size(1), " or ",
+ 2 * wq.size(1), " columns, but got ", xq_sizes.back());
+ const auto x_scale_sizes = x_scale.sizes().vec();
+ for (auto i = 0; i < x_scale_sizes.size(); ++i)
+ TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
+ __func__, " : Expected xq scale argument size at position ",
+ i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
+ TORCH_CHECK(w_scale.numel() == wq.size(0),
+ __func__, " : Expected wq scale argument to have ", wq.size(0),
+ " elements, got ", w_scale.numel(), " elements");
+ if (bias.numel() > 0) {
+ TORCH_CHECK(bias.numel() == wq.size(0),
+ __func__, " : Expected bias argument to have ", wq.size(0),
+ " elements, got ", bias.numel(), " elements");
+ }
+
+ // Validate strides of arguments.
+ const auto xq_strides = xq.strides();
+ TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
+ __func__, " : Expected xq argument in row-major layout");
+ auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
+ for (int i = xq_strides.size() - 3; i >= 0; --i) {
+ xq_stride_expected *= xq_sizes[i + 1];
+ TORCH_CHECK(xq_strides[i] == xq_stride_expected,
+ __func__, " : Expected xq argument in row-major layout");
+ }
+ TORCH_CHECK(x_scale.is_contiguous(),
+ __func__, " : Expected xq scale argument to be contiguous");
+ const auto wq_strides = wq.strides();
+ TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
+ __func__, " : Expected wq argument in row-major layout");
+ TORCH_CHECK(w_scale.is_contiguous(),
+ __func__, " : Expected wq scale argument to be contiguous");
+ if (bias.numel() > 0) {
+ const auto bias_strides = bias.strides();
+ TORCH_CHECK(bias_strides[0] == 1,
+ __func__, " : Expected bias argument to be contiguous");
+ }
+}
+#endif
+
+// Perform linear operation, using corresponding CUTLASS datatypes
+// GEMM kernel, to given arguments - result produced is:
+// (tensor_a * tensor_a_scale) @ (tensor_b * tensor_b_scale).T + tensor_c
+//
+// Notes: The "tensor_a" and "tensor_b" are expected to be 2D tensors.
+// The "tensor_a_scale" tensor is expected to be a vector, of size
+// equal to number of rows of "tensor_a" tensor. The "tensor_b_scale"
+// tensor is expected to be a vector, of size equal to number of rows
+// of "tensor_b" tensor. The "tensor_c" tensor is expected to be a
+// vector, of size equal to number of rows of "tensor_b" tensor.
+template <typename ElementA, typename ElementB>
+at::Tensor
+rowwise_scaled_linear_cutlass(
+ const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+ const at::Tensor& w_scale, const at::Tensor& bias) {
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+ // Check inputs.
+ rowwise_scaled_linear_cutlass_check_inputs<ElementA, ElementB>(
+ xq, x_scale, wq, w_scale, bias);
+
+ // Squash the input tensors as appropriate.
+ const auto xq_sizes = xq.sizes().vec();
+ const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
+ const auto x_scale_1d = x_scale.reshape({-1});
+ const auto w_scale_1d = w_scale.reshape({-1});
+
+ // Create result tensor.
+ at::Tensor result =
+ x_scale.new_empty({xq_2d.size(0), wq.size(0)});
+
+ // Dispatch to appropriate kernel template.
+ dispatch_on_tensor_a_scale_and_tensor_b_scale<ElementA, ElementB>(
+ xq_2d, x_scale_1d, wq, w_scale_1d, bias, result);
+
+ // Reshape and return result tensor.
+ auto result_sizes = xq_sizes;
+ result_sizes.back() = wq.size(0);
+ return result.reshape(result_sizes);
+#else
+ TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+ return at::Tensor{};
+#endif
+}
+
+} // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
new file mode 100644
index 0000000..9a64b2b
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#include "rowwise_scaled_linear_cutlass.cuh"
+
+namespace torchao {
+
+at::Tensor
+rowwise_scaled_linear_cutlass_s4s4(
+ const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+ const at::Tensor& w_scale, const at::Tensor& bias) {
+ // Validate input datatypes.
+ TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar,
+ __func__, " : The input datatypes combination ", xq.dtype(),
+ " for xq and ", wq.dtype(), " for wq is not supported");
+
+ // Dispatch to appropriate kernel template.
+ using ElementA = cutlass::int4b_t;
+ using ElementB = cutlass::int4b_t;
+ return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+ xq, x_scale, wq, w_scale, bias);
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+ m.impl("torchao::rowwise_scaled_linear_cutlass_s4s4",
+ &rowwise_scaled_linear_cutlass_s4s4);
+}
+
+} // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
new file mode 100644
index 0000000..752c557
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#include "rowwise_scaled_linear_cutlass.cuh"
+
+namespace torchao {
+
+at::Tensor
+rowwise_scaled_linear_cutlass_s8s4(
+ const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+ const at::Tensor& w_scale, const at::Tensor& bias) {
+ // Validate input datatypes.
+ TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar,
+ __func__, " : The input datatypes combination ", xq.dtype(),
+ " for xq and ", wq.dtype(), " for wq is not supported");
+
+ // Dispatch to appropriate kernel template.
+ using ElementA = int8_t;
+ using ElementB = cutlass::int4b_t;
+ return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+ xq, x_scale, wq, w_scale, bias);
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+ m.impl("torchao::rowwise_scaled_linear_cutlass_s8s4",
+ &rowwise_scaled_linear_cutlass_s8s4);
+}
+
+} // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu
deleted file mode 100644
index 8faf13c..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu
+++ /dev/null
@@ -1,268 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
- defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_S4S4_LINEAR_CUTLASS
-#endif
-
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-#include "scaled_linear.h"
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-
-template<typename... Types>
-static void select_config(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- const auto dprops = at::cuda::getCurrentDeviceProperties();
- const auto is_sm8x = dprops->major >= 8;
-
- if (is_sm8x) {
- using ElementA = cutlass::int4b_t;
- using ElementB = cutlass::int4b_t;
- using ElementAccumulator = int32_t;
-
- using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
- using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
- constexpr auto NumStages = 3;
- using Operator = cutlass::arch::OpMultiplyAddSaturate;
- // using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; // this does not work
- using ThreadblockSwizzle =
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
-
- scaled_linear_kernel_cutlass_sm8x<
- ThreadblockShape, WarpShape, InstructionShape, NumStages,
- ThreadblockSwizzle, ElementA, ElementB, ElementAccumulator, Operator,
- Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported on SM", dprops->major, ".",
- dprops->minor, " for given operands");
-}
-
-template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
-static void
-dispatch_on_tensor_c(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- if (tensor_c.numel() == 0) {
- using ElementC = ElementOutput;
- using UseTensorC = std::false_type;
- select_config<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- using UseTensorC = std::true_type;
- if (tensor_c.scalar_type() == at::ScalarType::Half) {
- using ElementC = cutlass::half_t;
- select_config<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
- using ElementC = cutlass::bfloat16_t;
- select_config<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for datatype ",
- tensor_c.scalar_type(), " for addend");
-}
-
-static void
-dispatch_on_tensor_a_scale_and_tensor_b_scale(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
- __func__, " : Operator not supported for output datatype ",
- tensor_d.scalar_type(), " as it's different from the first ",
- " operand scale datatype ", tensor_a_scale.scalar_type());
-
- if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
- tensor_b_scale.scalar_type() == at::ScalarType::Half) {
- using ElementAScale = cutlass::half_t;
- using ElementBScale = cutlass::half_t;
- using ElementOutput = cutlass::half_t;
- dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
- return;
- } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
- tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
- using ElementAScale = cutlass::bfloat16_t;
- using ElementBScale = cutlass::bfloat16_t;
- using ElementOutput = cutlass::bfloat16_t;
- dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for combination of data ",
- "types ", tensor_a_scale.scalar_type(),
- " for first operand scale and ", tensor_b_scale.scalar_type(),
- " for second operand scale");
-}
-
-static void
-check_inputs(
- const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
- const at::Tensor& w_scale, const at::Tensor& bias) {
- // Validate layouts of arguments.
- TORCH_CHECK(xq.dim() >= 2,
- __func__, " : Expected xq argument to be 2D or "
- "higher-dimensional tensor, got ", xq.dim(), " dims");
- TORCH_CHECK(xq.layout() == at::Layout::Strided,
- __func__, " : Expected xq argument to be strided, got layout ",
- xq.layout());
- TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
- __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
- "D tensor, got ", x_scale.dim(), " dims");
- TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
- __func__, " : Expected xq scale argument to be strided, got "
- "layout ", x_scale.layout());
- TORCH_CHECK(wq.dim() == 2,
- __func__, " : Expected wq argument to be 2D tensor, got ",
- wq.dim(), " dims");
- TORCH_CHECK(wq.layout() == at::Layout::Strided,
- __func__, " : Expected wq argument to be strided, got layout ",
- wq.layout());
- TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
- __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
- "got ", w_scale.dim(), " dims");
- TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
- __func__, " : Expected wq scale argument to be strided, got "
- "layout ", w_scale.layout());
- if (bias.numel() > 0) {
- TORCH_CHECK(bias.dim() == 1,
- __func__, " : Expected bias argument to be 1D tensor, got ",
- bias.dim(), " dims");
- TORCH_CHECK(bias.layout() == at::Layout::Strided,
- __func__, " : Expected bias argument to be strided, got ",
- "layout ", bias.layout());
- }
-
- // Validate sizes of arguments.
- const auto xq_sizes = xq.sizes().vec();
- TORCH_CHECK(xq_sizes.back() == wq.size(1),
- __func__, " : Expected xq argument to have ", wq.size(1),
- " columns, but got ", xq_sizes.back());
- const auto x_scale_sizes = x_scale.sizes().vec();
- for (auto i = 0; i < x_scale_sizes.size(); ++i)
- TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
- __func__, " : Expected xq scale argument size at position ",
- i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
- TORCH_CHECK(w_scale.numel() == wq.size(0),
- __func__, " : Expected wq scale argument to have ", wq.size(0),
- " elements, got ", w_scale.numel(), " elements");
- if (bias.numel() > 0) {
- TORCH_CHECK(bias.numel() == wq.size(0),
- __func__, " : Expected bias argument to have ", wq.size(0),
- " elements, got ", bias.numel(), " elements");
- }
-
- // Validate strides of arguments.
- const auto xq_strides = xq.strides();
- TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
- __func__, " : Expected xq argument in row-major layout");
- auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
- for (int i = xq_strides.size() - 3; i >= 0; --i) {
- xq_stride_expected *= xq_sizes[i + 1];
- TORCH_CHECK(xq_strides[i] == xq_stride_expected,
- __func__, " : Expected xq argument in row-major layout");
- }
- TORCH_CHECK(x_scale.is_contiguous(),
- __func__, " : Expected xq scale argument to be contiguous");
- const auto wq_strides = wq.strides();
- TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
- __func__, " : Expected wq argument in row-major layout");
- TORCH_CHECK(w_scale.is_contiguous(),
- __func__, " : Expected wq scale argument to be contiguous");
- if (bias.numel() > 0) {
- const auto bias_strides = bias.strides();
- TORCH_CHECK(bias_strides[0] == 1,
- __func__, " : Expected bias argument to be contiguous");
- }
-}
-#endif
-
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-// result = (xq * x_scale) @ (wq * w_scale).T + bias
-// Notes: The "x_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "xq" tensor. The "w_scale" tensor is
-// expected to be a vector, of size equal to number of rows of "wq"
-// tensor. The "bias" tensor is expected to be a vector, of size equal
-// to number of rows of "wq" tensor.
-at::Tensor
-s4s4_linear_cutlass(
- const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
- const at::Tensor& w_scale, const at::Tensor& bias) {
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
- // Check inputs.
- check_inputs(xq, x_scale, wq, w_scale, bias);
-
- // Squash the input tensors as appropriate.
- const auto xq_sizes = xq.sizes().vec();
- const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
- const auto x_scale_sizes = x_scale.sizes().vec();
- const auto x_scale_1d = x_scale.reshape({-1});
- const auto w_scale_1d = w_scale.reshape({-1});
-
- // Introduce alias names for arguments, according to the CUTLASS
- // naming conventions.
- const auto& tensor_a = xq_2d;
- const auto& tensor_a_scale = x_scale_1d;
- const auto& tensor_b = wq;
- const auto& tensor_b_scale = w_scale_1d;
- const auto& tensor_c = bias;
-
- // Create output tensor.
- at::Tensor tensor_d =
- tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
-
- // Dispatch to appropriate kernel template.
- dispatch_on_tensor_a_scale_and_tensor_b_scale(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-
- // Reshape and return output tensor.
- auto tensor_d_sizes = xq_sizes;
- tensor_d_sizes.back() = wq.size(0);
- return tensor_d.reshape(tensor_d_sizes);
-#else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
- m.impl("torchao::s4s4_linear_cutlass", &s4s4_linear_cutlass);
-}
-
-} // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
deleted file mode 100644
index 53eaf53..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
+++ /dev/null
@@ -1,315 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
- defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_S8S4_LINEAR_CUTLASS
-#endif
-
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-#include "scaled_linear.h"
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-
-template<typename ElementA, typename ElementB, typename... Types>
-static void select_config(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- const auto dprops = at::cuda::getCurrentDeviceProperties();
- const auto is_sm8x = dprops->major == 8;
-
- if (is_sm8x) {
- if constexpr (std::is_same<ElementA, int8_t>::value &&
- std::is_same<ElementB, cutlass::int4b_t>::value) {
- using ThreadblockSwizzle =
- cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
- using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-
- // A minimal heuristic to improve performance for small number
- // of inputs cases.
- if (tensor_a.size(0) <= 16) {
- using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
- using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
- constexpr auto NumStages = 6;
- scaled_linear_kernel_cutlass_sm8x<
- ThreadblockShape, WarpShape, InstructionShape, NumStages,
- ThreadblockSwizzle, ElementA, ElementB, Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- } else if (tensor_a.size(0) <= 32) {
- using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
- using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
- constexpr auto NumStages = 5;
- scaled_linear_kernel_cutlass_sm8x<
- ThreadblockShape, WarpShape, InstructionShape, NumStages,
- ThreadblockSwizzle, ElementA, ElementB, Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- } else {
- using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
- using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
- constexpr auto NumStages = 4;
- scaled_linear_kernel_cutlass_sm8x<
- ThreadblockShape, WarpShape, InstructionShape, NumStages,
- ThreadblockSwizzle, ElementA, ElementB, Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- }
- return;
- }
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported on SM", dprops->major, ".",
- dprops->minor, " for given operands");
-}
-
-template<typename... Types>
-static void
-dispatch_on_tensor_a_and_tensor_b(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- if (tensor_a.scalar_type() == at::ScalarType::Char) {
- if (tensor_b.scalar_type() == at::ScalarType::Char) {
- if (tensor_a.size(1) == 2 * tensor_b.size(1)) {
- using ElementA = int8_t;
- using ElementB = cutlass::int4b_t;
- using ElementAccumulator = int32_t;
- using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast;
- select_config<
- ElementA, ElementB, ElementAccumulator, Operator, Types...>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- }
- return;
- }
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for combination of data ",
- "types ", tensor_a.scalar_type(), " for first operand and ",
- tensor_b.scalar_type(), " for second operand");
-}
-
-
-template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
-static void
-dispatch_on_tensor_c(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- if (tensor_c.numel() == 0) {
- using ElementC = ElementOutput;
- using UseTensorC = std::false_type;
- dispatch_on_tensor_a_and_tensor_b<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- using UseTensorC = std::true_type;
- if (tensor_c.scalar_type() == at::ScalarType::Half) {
- using ElementC = cutlass::half_t;
- dispatch_on_tensor_a_and_tensor_b<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
- using ElementC = cutlass::bfloat16_t;
- dispatch_on_tensor_a_and_tensor_b<
- ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
- tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for datatype ",
- tensor_c.scalar_type(), " for addend");
-}
-
-static void
-dispatch_on_tensor_a_scale_and_tensor_b_scale(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
- __func__, " : Operator not supported for output datatype ",
- tensor_d.scalar_type(), " as it's different from the first ",
- " operand scale datatype ", tensor_a_scale.scalar_type());
-
- if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
- tensor_b_scale.scalar_type() == at::ScalarType::Half) {
- using ElementAScale = cutlass::half_t;
- using ElementBScale = cutlass::half_t;
- using ElementOutput = cutlass::half_t;
- dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
- return;
- } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
- tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
- using ElementAScale = cutlass::bfloat16_t;
- using ElementBScale = cutlass::bfloat16_t;
- using ElementOutput = cutlass::bfloat16_t;
- dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
- return;
- }
-
- TORCH_CHECK(false,
- __func__, " : Operator not supported for combination of data ",
- "types ", tensor_a_scale.scalar_type(),
- " for first operand scale and ", tensor_b_scale.scalar_type(),
- " for second operand scale");
-}
-
-static void
-check_inputs(
- const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
- const at::Tensor& w_scale, const at::Tensor& bias) {
- // Validate layouts of arguments.
- TORCH_CHECK(xq.dim() >= 2,
- __func__, " : Expected xq argument to be 2D or "
- "higher-dimensional tensor, got ", xq.dim(), " dims");
- TORCH_CHECK(xq.layout() == at::Layout::Strided,
- __func__, " : Expected xq argument to be strided, got layout ",
- xq.layout());
- TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
- __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
- "D tensor, got ", x_scale.dim(), " dims");
- TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
- __func__, " : Expected xq scale argument to be strided, got "
- "layout ", x_scale.layout());
- TORCH_CHECK(wq.dim() == 2,
- __func__, " : Expected wq argument to be 2D tensor, got ",
- wq.dim(), " dims");
- TORCH_CHECK(wq.layout() == at::Layout::Strided,
- __func__, " : Expected wq argument to be strided, got layout ",
- wq.layout());
- TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
- __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
- "got ", w_scale.dim(), " dims");
- TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
- __func__, " : Expected wq scale argument to be strided, got "
- "layout ", w_scale.layout());
- if (bias.numel() > 0) {
- TORCH_CHECK(bias.dim() == 1,
- __func__, " : Expected bias argument to be 1D tensor, got ",
- bias.dim(), " dims");
- TORCH_CHECK(bias.layout() == at::Layout::Strided,
- __func__, " : Expected bias argument to be strided, got ",
- "layout ", bias.layout());
- }
-
- // Validate sizes of arguments.
- const auto xq_sizes = xq.sizes().vec();
- TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1),
- __func__, " : Expected xq argument to have ", 2 * wq.size(1),
- " columns, but got ", xq_sizes.back());
- const auto x_scale_sizes = x_scale.sizes().vec();
- for (auto i = 0; i < x_scale_sizes.size(); ++i)
- TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
- __func__, " : Expected xq scale argument size at position ",
- i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
- TORCH_CHECK(w_scale.numel() == wq.size(0),
- __func__, " : Expected wq scale argument to have ", wq.size(0),
- " elements, got ", w_scale.numel(), " elements");
- if (bias.numel() > 0) {
- TORCH_CHECK(bias.numel() == wq.size(0),
- __func__, " : Expected bias argument to have ", wq.size(0),
- " elements, got ", bias.numel(), " elements");
- }
-
- // Validate strides of arguments.
- const auto xq_strides = xq.strides();
- TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
- __func__, " : Expected xq argument in row-major layout");
- auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
- for (int i = xq_strides.size() - 3; i >= 0; --i) {
- xq_stride_expected *= xq_sizes[i + 1];
- TORCH_CHECK(xq_strides[i] == xq_stride_expected,
- __func__, " : Expected xq argument in row-major layout");
- }
- TORCH_CHECK(x_scale.is_contiguous(),
- __func__, " : Expected xq scale argument to be contiguous");
- const auto wq_strides = wq.strides();
- TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
- __func__, " : Expected wq argument in row-major layout");
- TORCH_CHECK(w_scale.is_contiguous(),
- __func__, " : Expected wq scale argument to be contiguous");
- if (bias.numel() > 0) {
- const auto bias_strides = bias.strides();
- TORCH_CHECK(bias_strides[0] == 1,
- __func__, " : Expected bias argument to be contiguous");
- }
-}
-#endif
-
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-// result = (xq * x_scale) @ (wq * w_scale).T + bias
-// Notes: The "x_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "xq" tensor. The "w_scale" tensor is
-// expected to be a vector, of size equal to number of rows of "wq"
-// tensor. The "bias" tensor is expected to be a vector, of size equal
-// to number of rows of "wq" tensor.
-at::Tensor
-s8s4_linear_cutlass(
- const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
- const at::Tensor& w_scale, const at::Tensor& bias) {
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
- // Check inputs.
- check_inputs(xq, x_scale, wq, w_scale, bias);
-
- // Squash the input tensors as appropriate.
- const auto xq_sizes = xq.sizes().vec();
- const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
- const auto x_scale_sizes = x_scale.sizes().vec();
- const auto x_scale_1d = x_scale.reshape({-1});
- const auto w_scale_1d = w_scale.reshape({-1});
-
- // Introduce alias names for arguments, according to the CUTLASS
- // naming conventions.
- const auto& tensor_a = xq_2d;
- const auto& tensor_a_scale = x_scale_1d;
- const auto& tensor_b = wq;
- const auto& tensor_b_scale = w_scale_1d;
- const auto& tensor_c = bias;
-
- // Create output tensor.
- at::Tensor tensor_d =
- tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
-
- // Dispatch to appropriate kernel template.
- dispatch_on_tensor_a_scale_and_tensor_b_scale(
- tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-
- // Reshape and return output tensor.
- auto tensor_d_sizes = xq_sizes;
- tensor_d_sizes.back() = wq.size(0);
- return tensor_d.reshape(tensor_d_sizes);
-#else
- TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
- return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
- m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass);
-}
-
-} // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h b/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h
deleted file mode 100644
index 991384b..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h
+++ /dev/null
@@ -1,288 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
-#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-
-#define CUTLASS_STATUS_CHECK(status) \
- { \
- TORCH_CHECK(status == cutlass::Status::kSuccess, \
- __func__, " : Got CUTLASS error: ", \
- cutlassGetStatusString(status)); \
- }
-
-namespace torchao {
-
-template<
- typename ThreadblockShape,
- typename WarpShape,
- typename InstructionShape,
- int NumStages,
- typename ThreadblockSwizzle,
- typename ElementA,
- typename ElementB,
- typename ElementAccumulator,
- typename Operator,
- typename ElementAScale,
- typename ElementBScale,
- typename ElementC,
- typename UseTensorC,
- typename ElementOutput>
-void scaled_linear_kernel_cutlass_sm8x(
- const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
- const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
- const at::Tensor& tensor_c, at::Tensor& tensor_d) {
- using SmArch = cutlass::arch::Sm80;
-
- using LayoutA = cutlass::layout::RowMajor;
- using LayoutB = cutlass::layout::ColumnMajor;
- using LayoutOutput = cutlass::layout::RowMajor;
-
- using ElementEpilogue = float;
- constexpr auto NumEVTEpilogueStages = 1;
-
- const int m = tensor_a.size(0);
- const int n = tensor_b.size(0);
- const int k = std::is_same<ElementA, cutlass::int4b_t>::value ?
- tensor_a.size(1) * 2 :
- tensor_a.size(1);
-
- constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
- constexpr int AlignmentAScale =
- 128 / cutlass::sizeof_bits<ElementAScale>::value;
- constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
- constexpr int AlignmentBScale =
- 128 / cutlass::sizeof_bits<ElementBScale>::value;
- constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
- constexpr int AlignmentOutput =
- 128 / cutlass::sizeof_bits<ElementOutput>::value;
-
- // Check for current CUTLASS limitations w.r.t. alignments.
- TORCH_CHECK(k % AlignmentA == 0,
- __func__, " : Number of columns of tensor A must be divisible ",
- "by ", AlignmentA);
- TORCH_CHECK(k % AlignmentB == 0,
- __func__, " : Number of columns of tensor B must be divisible ",
- "by ", AlignmentB);
- TORCH_CHECK(n % AlignmentC == 0,
- __func__, " : Number of columns of tensor C must be divisible ",
- "by ", AlignmentC);
-
- using TensorAScaleTileThreadMap =
- cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape,
- WarpShape,
- ElementAScale,
- AlignmentAScale,
- NumEVTEpilogueStages>;
- using TensorBScaleTileThreadMap =
- cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape,
- WarpShape,
- ElementBScale,
- AlignmentBScale,
- NumEVTEpilogueStages>;
- using TensorCTileThreadMap =
- cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape,
- WarpShape,
- ElementC,
- AlignmentC,
- NumEVTEpilogueStages>;
- using OutputTileThreadMap =
- cutlass::epilogue::threadblock::OutputTileThreadLayout<
- ThreadblockShape,
- WarpShape,
- ElementOutput,
- AlignmentOutput,
- NumEVTEpilogueStages>;
-
- using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
-
- using TensorAScale =
- cutlass::epilogue::threadblock::VisitorColBroadcast<
- TensorAScaleTileThreadMap,
- ElementAScale,
- cute::Stride<cute::_1, cute::_0, int64_t>>;
- using TensorAScaleArguments = typename TensorAScale::Arguments;
-
- using TensorBScale =
- cutlass::epilogue::threadblock::VisitorRowBroadcast<
- TensorBScaleTileThreadMap,
- ElementBScale,
- cute::Stride<cute::_0, cute::_1, int64_t>>;
- using TensorBScaleArguments = typename TensorBScale::Arguments;
-
- using TensorCScalar =
- cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
- using TensorCTensor =
- cutlass::epilogue::threadblock::VisitorRowBroadcast<
- TensorCTileThreadMap,
- ElementC,
- cute::Stride<cute::_0, cute::_1, int64_t>>;
- using TensorC =
- std::conditional_t<UseTensorC::value, TensorCTensor, TensorCScalar>;
- using TensorCArguments = typename TensorC::Arguments;
-
- using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::multiplies, ElementEpilogue, ElementEpilogue,
- cutlass::FloatRoundStyle::round_to_nearest
- >;
- using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
- ApplyAScale,
- Accum,
- TensorAScale>;
-
- using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::multiplies, ElementEpilogue, ElementEpilogue,
- cutlass::FloatRoundStyle::round_to_nearest
- >;
- using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
- ApplyBScale,
- EVTApplyAScale,
- TensorBScale>;
-
- using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
- cutlass::plus, ElementEpilogue, ElementEpilogue,
- cutlass::FloatRoundStyle::round_to_nearest
- >;
- using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
- ApplySum,
- EVTApplyBScale,
- TensorC>;
-
- using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
- OutputTileThreadMap, ElementOutput,
- cutlass::FloatRoundStyle::round_to_nearest,
- cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
- >;
-
- using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
- Output,
- EVTApplySum>;
-
- using EVTKernel =
- typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
- ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
- ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
- ElementOutput, LayoutOutput, AlignmentOutput,
- ElementAccumulator,
- ElementEpilogue,
- cutlass::arch::OpClassTensorOp,
- SmArch,
- ThreadblockShape,
- WarpShape,
- InstructionShape,
- EVTOutput,
- ThreadblockSwizzle,
- NumStages,
- Operator,
- NumEVTEpilogueStages
- >::GemmKernel;
-
- // GemmUniversalBase doesn't work with W4A4
- // using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
- using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
-
- cutlass::gemm::GemmCoord problem_size(m, n, k);
- constexpr auto SplitKFactor = 1;
-
- TensorAScaleArguments tensor_a_scale_arguments{
- (ElementAScale*)tensor_a_scale.data_ptr(),
- ElementAScale(1),
- {cute::_1{}, cute::_0{}, problem_size.m()}
- };
- TensorBScaleArguments tensor_b_scale_arguments{
- (ElementBScale*)tensor_b_scale.data_ptr(),
- ElementBScale(1),
- {cute::_0{}, cute::_1{}, problem_size.n()}
- };
- TensorCArguments tensor_c_arguments{
- [&]() -> TensorCArguments {
- if constexpr (UseTensorC::value) {
- return {(ElementC*)tensor_c.data_ptr(),
- ElementC(0),
- {cute::_0{}, cute::_1{}, problem_size.n()}};
- } else {
- return {ElementC(0)};
- }
- }()
- };
- typename Output::Arguments output_arguments{
- (ElementOutput*)tensor_d.data_ptr(),
- {problem_size.n(), cute::_1{}, problem_size.mn().product()}
- };
- typename EVTOutput::Arguments callback_arguments{
- {
- {
- {
- {}, // Accum
- tensor_a_scale_arguments, // TensorAScale
- {} // ApplyAScale
- }, // EVTApplyAScale
- tensor_b_scale_arguments, // TensorBScale
- {}, // ApplyBScale
- }, // EVTApplyBScale
- tensor_c_arguments, // TensorC
- {} // ApplySum
- }, // EVTApplySum
- output_arguments // Output
- }; // EVTOutput
- // constexpr auto AvailSms = -1;
-
- typename Gemm::Arguments arguments(
- cutlass::gemm::GemmUniversalMode::kGemm,
- problem_size,
- SplitKFactor,
- callback_arguments, // arguments of EVT callbacks
- (ElementA*)tensor_a.data_ptr(),
- (ElementB*)tensor_b.data_ptr(),
- nullptr, // ptr C (unused)
- nullptr, // ptr D (unused)
- problem_size.mk().product(), // batch stride A
- problem_size.nk().product(), // batch stride B
- 0, // batch stride C (unused)
- 0, // batch stride D (unused)
- problem_size.k(), // stride A
- problem_size.k(), // stride B
- 0, // stride C (unused)
- 0
- // , // stride D (unused)
- // AvailSms // GemmUniversalBase requires passing AvailSms, but GemmUniversalAdapter doesn't
- );
-
- Gemm gemm_op;
-
- cutlass::Status status;
-
- // Verify that GEMM operation with given arguments can be performed
- // by CUTLASS.
- status = gemm_op.can_implement(arguments);
- CUTLASS_STATUS_CHECK(status);
-
- // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
- const auto workspace_size = Gemm::get_workspace_size(arguments);
- auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
- at::TensorOptions().dtype(at::kByte));
-
- // Initialize CUTLASS mixed datatypes GEMM object.
- status = gemm_op.initialize(arguments, workspace.data_ptr(),
- at::cuda::getCurrentCUDAStream());
- CUTLASS_STATUS_CHECK(status);
-
- // Perform mixed datatypes GEMM operation.
- status = gemm_op.run(at::cuda::getCurrentCUDAStream());
- CUTLASS_STATUS_CHECK(status);
-
- C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-} // namespace torchao
diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
index d7374c8..bccbf80 100644
--- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
+++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -144,14 +144,14 @@ def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias
def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
- from torchao.ops import s8s4_linear_cutlass
+ from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
weight = weight_tensor.tensor_impl.int_data
weight_scale = weight_tensor.tensor_impl.scale
input = input_tensor.tensor_impl.int_data
input_scale = input_tensor.tensor_impl.scale
- out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias)
+ out = rowwise_scaled_linear_cutlass_s8s4(input, input_scale, weight, weight_scale, bias)
return out
diff --git a/torchao/ops.py b/torchao/ops.py
index 840dbc0..272b358 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -20,11 +20,10 @@ lib.define(
"marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor"
)
lib.define(
- "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
+ "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
)
-lib.define("int4_mm_cutlass(Tensor A, Tensor B) -> Tensor")
lib.define(
- "scaled_int4_mm_cutlass(Tensor A, Tensor B, Tensor row_scale, Tensor col_scale) -> Tensor"
+ "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
)
@@ -518,7 +517,7 @@ def _(
return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device)
-def s8s4_linear_cutlass(
+def rowwise_scaled_linear_cutlass_s8s4(
input: Tensor,
input_scale: Tensor,
weight: Tensor,
@@ -526,23 +525,23 @@ def s8s4_linear_cutlass(
bias: Tensor,
) -> Tensor:
"""
- CUTLASS-based W4A8 linear operator.
+ CUTLASS-based row-wise scaled linear operator.
Args:
- input: input tensor, quantized to 8-bit integer values.
+ input: quantized input tensor, in row-major layout.
input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension.
- weight: weight matrix, quantized to 4-bit integer values, in row-major layout.
+ weight: quantized weight matrix, in row-major layout.
weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension).
bias: a vector of size equal to number of rows of weight tensor, or None.
Returns:
output: result tensor, in row-major layout.
"""
- return torch.ops.torchao.s8s4_linear_cutlass.default(
+ return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default(
input, input_scale, weight, weight_scale, bias
)
-@register_custom_op("torchao::s8s4_linear_cutlass")
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s8s4")
def _(
input: Tensor,
input_scale: Tensor,
@@ -550,6 +549,8 @@ def _(
weight_scale: Tensor,
bias: Tensor,
) -> Tensor:
+ # FIXME: update this!!!
+
# Validate dtypes.
torch._check(
input.dtype == torch.int8,
@@ -621,29 +622,8 @@ def _(
)
-def int4_mm_cutlass(A: Tensor, B: Tensor) -> Tensor:
- """
- CUTLASS-based W4A4 matmul.
- Args:
- A: first INT4 tensor, packed in INT8 dtype, row-major layout.
- B: second INT4 tensor, packed in INT8 dtype, column-major layout.
- Returns:
- output: result tensor, in row-major layout.
- """
- assert A.dtype == B.dtype == torch.int8
- assert A.ndim == B.ndim == 2
- assert A.shape[1] == B.shape[0]
- assert A.is_contiguous() and B.T.is_contiguous()
- return torch.ops.torchao.int4_mm_cutlass.default(A, B)
-
-
-@register_custom_op("torchao::int4_mm_cutlass")
-def _(A: Tensor, B: Tensor) -> Tensor:
- return A.new_empty(A.shape[0], B.shape[1], dtype=torch.int32)
-
-
-def scaled_int4_mm_cutlass(
- A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor
+def rowwise_scaled_linear_cutlass_s4s4(
+ A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor
) -> Tensor:
"""
CUTLASS-based W4A4 scaled-matmul.
@@ -656,15 +636,16 @@ def scaled_int4_mm_cutlass(
output: result tensor, in row-major layout.
"""
assert A.dtype == B.dtype == torch.int8
- assert A.ndim == B.ndim == 2
- assert A.shape[1] == B.shape[0]
- assert A.is_contiguous() and B.T.is_contiguous()
- assert row_scale.ndim == col_scale.ndim == 1
+ assert A.ndim >= 2
+ assert B.ndim == 2
+ assert A.shape[-1] == B.shape[-1]
+ assert A.is_contiguous() and B.is_contiguous()
+ assert row_scale.ndim == col_scale.ndim == 2
assert row_scale.dtype == col_scale.dtype
assert row_scale.dtype in (torch.float16, torch.bfloat16)
- return torch.ops.torchao.scaled_int4_mm_cutlass.default(A, B, row_scale, col_scale)
+ return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default(A, row_scale, B, col_scale, bias)
-@register_custom_op("torchao::scaled_int4_mm_cutlass")
-def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor:
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s4s4")
+def _(A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor) -> Tensor:
    return row_scale.new_empty(A.shape[0], B.shape[1])
I did some renaming, so please take a look and let me know what you think. Regarding your question about the possibility of different types for scales and bias - this pretty much comes for free, and I would be inclined to extend this further so that either of the input/weight scales and/or the bias could be of a different type. On the other side, I have a related question here: looking at your
My TODO list regarding the attached patch:
I'll continue on this tomorrow. |
Thank you for the patch! Will apply and work on it later today.
I don't think there is a standalone real use case for INT4xINT4->INT32 (yet), similarly for INT8xINT8->INT32 i.e.
From my observation, it's a good idea to have input checks for the meta implementation used by torch.compile for shape tracing (otherwise we won't catch errors at compile time, only at runtime) -> we need some kind of input checks in Python anyway. Hence, I usually don't bother with input checks on the C++ (or Triton) side, and put all the checks in Python. I place them in the Python wrapper so that they are shared between the meta and CUDA implementations (actually, regardless of the implementation) -> technically the op itself doesn't have input checks, but the Python wrapper does. See e.g. ao/torchao/prototype/quantized_training/int8_mm.py, lines 137 to 150 (at 5d1444b).
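For illustration, the pattern being described looks roughly like this (a minimal sketch with made-up names, not the actual code from int8_mm.py):

```python
import torch
from torch import Tensor


def my_scaled_mm(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor:
    # Shared input checks: they run in eager mode and also during torch.compile
    # shape tracing, because tracing goes through this Python wrapper before
    # hitting the meta/fake implementation.
    assert A.dtype == B.dtype == torch.int8
    assert A.ndim == B.ndim == 2
    assert A.shape[1] == B.shape[0]
    assert row_scale.dtype == col_scale.dtype
    return torch.ops.torchao.my_scaled_mm.default(A, B, row_scale, col_scale)


# The meta implementation (registered e.g. via torchao's register_custom_op
# decorator) then only needs to produce an output of the right shape and dtype;
# it relies on the wrapper above for all validation.
def my_scaled_mm_meta(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor:
    return row_scale.new_empty(A.shape[0], B.shape[1])
```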
That's just my personal view (and I prefer writing Python over C++ any time 😆). |
Found a typo:
+ int k = tensor_a.size(1);
+ if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
+ k *= 8 % cutlass::sizeof_bits<ElementA>::value;
+ }

Should be (integer) division instead of modulo. After I fixed that, the tests pass. I also took the liberty of slightly changing the test: matmul, scaling, and bias addition are now done in FP32, which should match the numerics of the fused CUTLASS kernel better, especially for the scaling and bias-addition parts, since FP16/BF16 matmul is accumulated in FP32 anyway (a sketch of this reference computation is included at the end of this comment). With this change, we can use

With your refactoring, it should be trivial to add W8A8. Should we add W8A8 also? (We might not add it to AQT since AQT currently uses torch.compile to codegen this path. Though I think inductor only fuses 1 scaling (either input scale or weight scale), and we can get a bit more perf by fusing both input and weight scales.)

Other ideas:
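A sketch of the FP32 reference computation mentioned above (tensor names are illustrative, int4 packing is ignored; the actual test code may differ):

```python
import torch


def reference_rowwise_scaled_linear(xq, x_scale, wq, w_scale, bias, out_dtype):
    # Matmul on the integer-valued activations/weights, then row-wise scaling
    # and bias addition, all in FP32 - mirroring the fused CUTLASS epilogue,
    # which applies both scales and the bias on the FP32 accumulator before
    # the final cast to FP16/BF16.
    out = xq.float() @ wq.float().T
    out = out * x_scale.float().unsqueeze(-1)  # one scale per activation row
    out = out * w_scale.float()                # one scale per weight row (output channel)
    if bias is not None:
        out = out + bias.float()
    return out.to(out_dtype)
```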
|
I'd prefer as many CUTLASS-based kernels as possible to be implemented through what we came up with; however, for now I'd restrict this to what is asked for by users (which means we stay, at the moment, with W4A4/W4A8, and extend on demand). The reason is that, through another CUTLASS-based kernel I'm working on, I hope to investigate a more systematic approach to selecting performant configs, and then eventually to update these two variants that we have, and also to have a ready-made approach for new kernels. Good config selection is the hardest thing here - for the W4A8 kernel, I've spent considerable time benchmarking manually in order to come up with the corresponding config selection code, and I still am not very happy with it (the kind of sweep involved is sketched at the end of this comment). As far as using these kernels as a base for benchmarking: my approach throughout the development has been to compare against
My understanding matches what you said. So when the code runs in eager mode, the corresponding operator from

So I believe that, if we keep the checks in the C++ code, we can remove all the checks from
Awesome, sorry for the bug. The tests look great now. Shall we put some effort into extracting the common stuff between the tests? I can do that.
I've implemented this through this PR; it's merged in PyTorch. I agree it belongs here, and in the future I'd like to migrate this one, as well as some other kernels (I have some CUTLASS-based F16/S8 stuff there, implemented before torchao started), into torchao.
Completely agree with that. There is some support for CUTLASS-based auto-tuning, and I did some work there; but it sort of predates most of the CUTLASS Python API development, so it's based on pretty much manually generating all of the C++ code, similar to what we have in the

For this PR, I'll try to complete the remaining items from my list:
Will post a patch here. |
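As mentioned earlier in this comment, the manual benchmarking behind the config selection boils down to something like the sketch below (illustrative only; in practice each candidate is a separately built kernel with its own ThreadblockShape/WarpShape/NumStages choice, and the winners are then translated by hand into the if/else thresholds of select_config()):

```python
import torch
from torch.utils.benchmark import Timer


def sweep_configs(candidates, make_inputs, shapes):
    # candidates: dict mapping a config name to a callable wrapping the kernel
    # built with that config; make_inputs(m, n, k) creates the operand tensors.
    # Returns the fastest config per (m, n, k) shape.
    winners = {}
    for shape in shapes:
        args = make_inputs(*shape)
        timings = {
            name: Timer("op(*args)", globals={"op": op, "args": args})
            .blocked_autorange()
            .median
            for name, op in candidates.items()
        }
        winners[shape] = min(timings, key=timings.get)
    return winners
```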
@alexsamardzic Thank you for the update. Was not aware of the latest work in PyTorch core :)
Just want to point out a small note. Having checks under

Btw, it feels wrong to have you do most of the hard part while the commit is attributed to me. Can you commit directly to my branch? If not, I can add you as a collaborator on my fork. |
You're right, and it kind of sucks. On the other hand, if an error is reported by the tracer, the error message says something like
Don't worry about that - it's fine with me, and I should have just one small patch to add. |
Here it is: |
Closes #1406
Thanks to #880, we now have a CUTLASS (3.6.0) copy in torchao. Adding W4A4 is pretty straightforward, similar to how W4A8 is done. This is largely copied from my other repo, so I didn't exactly follow @alexsamardzic's style. Requesting a first round of review.
Note: this is mainly to make experimenting with W4A4 easier. Personally I don't think it's too useful at the moment, since W4A4 accuracy is probably quite bad.
TODO: