Add CUTLASS-based W4A4 #1515

Draft: wants to merge 14 commits into main

Conversation

@gau-nernst (Collaborator) commented Jan 7, 2025

Closes #1406

Thanks to #880, we now have a CUTLASS (3.6.0) copy in torchao. Adding W4A4 is pretty straightforward, similar to how W4A8 is done. This is largely copied from my other repo, so I didn't exactly follow @alexsamardzic's style. Requesting a first round of review.

Note: this is mostly to make experimenting with W4A4 easier. Personally I don't think it's too useful at the moment, since W4A4 accuracy is probably quite bad.

TODO:

  • Hook up to AQT
  • Benchmark script and get benchmark results for A100 and 4090
  • (Maybe) Do some tuning + heuristics based on problem size
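
For reference, the W4A4 operands here are signed int4 values packed two per int8 byte, low nibble first (the same convention used by the tests later in this thread). A minimal packing/unpacking sketch, with hypothetical helper names:

import torch

def pack_int4(x_s8: torch.Tensor) -> torch.Tensor:
    # Pack pairs of int4 values (stored one per int8) into single int8 bytes,
    # low nibble first; matches the expression used in the tests below.
    assert x_s8.dtype == torch.int8 and x_s8.size(-1) % 2 == 0
    return ((x_s8[..., 1::2] & 0xF) << 4) | (x_s8[..., 0::2] & 0xF)

def unpack_int4(x_packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_int4: recover sign-extended int4 values as int8.
    lo = x_packed & 0xF
    hi = (x_packed >> 4) & 0xF
    lo = torch.where(lo >= 8, lo - 16, lo)
    hi = torch.where(hi >= 8, hi - 16, hi)
    return torch.stack((lo, hi), dim=-1).flatten(-2)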

pytorch-bot commented Jan 7, 2025: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1515 (this comment is automatically generated by Dr. CI and updates every 15 minutes).

@facebook-github-bot added the CLA Signed label Jan 7, 2025
@gau-nernst added the topic: new feature label Jan 7, 2025
@alexsamardzic (Collaborator)

The CUDA code looks fine; of course, there are still lots of dots to connect on the Python side.

The difference from #880 is that this is not a mixed-data-types GEMM, but a regular GEMM instead. In that regard, this operator could more easily be made much more generic, to support other integer and maybe even some floating-point input data types. I'm at the moment making some minor changes to this PyTorch operator, and would strongly recommend modelling the CUDA code in a similar way: it simply looks nice, it makes extending the kernel to other data types much easier, it has extensive checks on the operands, etc. Moreover, I think it would make sense at this point to discuss having a single CUTLASS-based kernel for GEMMs with both weights and activations scaled, put into a single source file, handling both same and mixed data-type GEMMs, at least for SM 8.x archs; that would provide for minimum code duplication and easier maintenance in the future.

As far as configurations (tile sizes, number of stages, etc.) are concerned, I'd suggest looking here (in the unit tests) instead, and also comparing performance vs. the results reported by the CUTLASS profiler for the given combination of data types. I believe some sort of tuning of the configuration based on the input shapes is a must in order to achieve decent performance; but I have to admit that in #880 the tuning is mostly ad hoc (for comparison, I find this approach more elaborate and meaningful). Thus, I think that coming up with some kind of systematic approach in that regard would be the most beneficial contribution toward eventual future use of CUTLASS-based kernels in torchao.

(@drisspg: Your comments welcome here.)

@drisspg (Contributor) commented Jan 7, 2025

One thing on finding optimal params: @yifuwang was recently working on finding better configs for an AsyncMM. He did some manual elimination of configs that never seemed to be performant, and then fit a simple decision tree on a big sweep over M/K/N shapes, so that it could be easily modeled in C++. This is similar to what is done in the RowWise scaling. I think a little flow for this would be helpful; I can make an issue to track it.
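
A rough sketch of what such a flow could look like (not the actual implementation; the sweep data, helper, and config indices below are made up): fit a shallow decision tree mapping (M, K, N) to the best-measured config index, then print it as nested C++ if/else.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical sweep results: rows of (M, K, N, index of fastest config).
sweep = np.array([
    [16, 4096, 4096, 0],
    [64, 4096, 4096, 1],
    [256, 4096, 4096, 2],
    [16, 8192, 8192, 0],
    [512, 8192, 8192, 2],
])
X, y = sweep[:, :3], sweep[:, 3]

tree = DecisionTreeClassifier(max_depth=3).fit(X, y)

def emit_cpp(node=0, indent="  "):
    # Recursively print the fitted tree as nested C++ if/else.
    t = tree.tree_
    if t.children_left[node] == -1:  # leaf node
        best = int(tree.classes_[np.argmax(t.value[node])])
        print(f"{indent}return {best};")
        return
    feat = ["m", "k", "n"][t.feature[node]]
    print(f"{indent}if ({feat} <= {t.threshold[node]:.0f}) {{")
    emit_cpp(t.children_left[node], indent + "  ")
    print(f"{indent}}} else {{")
    emit_cpp(t.children_right[node], indent + "  ")
    print(f"{indent}}}")

print("int select_config_idx(int m, int k, int n) {")
emit_cpp()
print("}")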

No major comments

@gau-nernst (Collaborator, Author)

Thank you for the feedback.

In that regard, this operator here is maybe easier to be made much more generic, to support other integer and maybe even some floating point input data types

Though this is nice on paper, I think Triton is the better alternative for other data types (INT8, FP8...). It's more flexible, and the autotuner also saves us some headache. It's only because of the lack of INT4 support in Triton that we have to use CUTLASS, especially for the INT4 tensor cores. Unless we can show that there are cases where Triton cannot reach the perf of CUTLASS (in the context of this PR, I'm only thinking about INT8 for SM8x, and additionally FP8 for SM89).

Having said that, I'm OK with following a certain style/structure. Just point me to which one it should be, and I will make the modifications accordingly.

torchao/ops.py Outdated
Returns:
output: result tensor, in row-major layout.
"""
assert A.dtype == B.dtype == torch.int8
Contributor:

Should add the alignment constraints as well right?

@gau-nernst (Collaborator, Author), Jan 8, 2025:

How should I check for data alignment from Python? I guess in C++, I can check by testing divisibility of the memory address? (or perhaps there is a util function somewhere that I'm not aware of...)

Contributor:

Hmm, I think there is a restriction that k needs to be a multiple of 32, right?

Contributor:

Or at least 16 packed int4s.
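
For the Python-side check discussed above, a minimal sketch (my own hypothetical helper, assuming packed int4 stored as int8 and the 128-bit, i.e. 16-byte, access granularity mentioned here):

import torch

def check_packed_int4_operand(t: torch.Tensor, packed_k_dim: int = -1) -> None:
    # Operand stored as int8, each byte packing two int4 values.
    assert t.dtype == torch.int8
    # 128-bit accesses load 16 int8 bytes, i.e. 32 int4 elements, so the packed
    # K dimension should be a multiple of 16 bytes (unpacked K % 32 == 0).
    assert t.size(packed_k_dim) % 16 == 0, \
        "packed K must be a multiple of 16 bytes (32 int4 values)"
    # Pointer alignment can be checked from Python via data_ptr().
    assert t.data_ptr() % 16 == 0, "operand storage must be 16-byte aligned"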

using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
// static int const kStages = 3;
using ElementC = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
Contributor:

Do you know if the universal GEMM API can be used?

@gau-nernst (Collaborator, Author):

Will look into it. I wrote this quite some time ago...

@alexsamardzic (Collaborator)

@gau-nernst

Attached is a minor patch that changes the s8s4_linear_cutlass operator to do W4A4. I did it in a quick-and-dirty way, so maybe I made some kind of error (the tests run, but won't pass), but the point is that the main differences are in: checking the arguments (I just commented these checks out), dispatching depending on the input data types (I just changed the input data type in this patch instead of having full dispatching), and selecting the configuration (for W4A4 I just put in the same configuration that you use in your patch). But the CUTLASS boilerplate code is completely the same, except for using OpMultiplyAddMixedInputUpcast as the operator in the mixed-input-data-types case. So my point is that, for maintenance reasons, I think we'd better have a single CUTLASS-based kernel for W4A8 and W4A4 (and then all alike) instead of duplicating lots of code.

The structure of CUTLASS-based kernels is typically always the same (see also the rowwise scaled MM in PyTorch mentioned in my previous comment, as well as my CUTLASS-based mixed-data-types and 2:4 sparsity kernels in PyTorch): from the bottom up, there is always an operator implementation function that checks the inputs and then starts a dispatching chain (where run-time data types etc. are translated into compile-time template arguments), which ends in a typical CUTLASS-based GEMM kernel (that is the boilerplate). Also, as mentioned in my previous comment, while rowwise scaled MM is very similar in structure, I like how it looks the most, because of the clever use of variadic template arguments to reduce the clutter, and because of the clean extraction of the input checks and the configuration selection into separate functions, etc. So I'd suggest we have your C++ code integrated in the way sketched by the attached diff, and then also make minor changes in the C++ code to bring it closer to the rowwise scaled MM implementation. (Of course, the operator name and some other stuff on the Python side will have to be changed too.)

diff.txt

@alexsamardzic (Collaborator) commented Jan 8, 2025

As far as the performance of the various implementations is concerned: I'd say there are in general three ways to implement kernels: Triton-based, CUTLASS-based, and custom, i.e. written from scratch (like the Marlin-based kernels). In my experience so far (all on the Ampere arch), CUTLASS-based kernels are oftentimes somewhat faster than Triton-based kernels, while for some corner-case input tensor sizes, custom kernels (well, Marlin-based at least) can be significantly faster than CUTLASS-based ones. Furthermore, with Triton there is the least amount of flexibility regarding upstream changes (they just don't support some input data types, they don't support 2:4 sparsity, etc.); with CUTLASS it's somewhat easier to have the changes we may need accepted; and for custom kernels this is obviously not an issue at all. However, Triton kills it when it comes to compilation, in particular regarding fusing the GEMM with other kernels; CUTLASS has some support for compilation, but doing fusion is rather cumbersome at the moment; and there is obviously no compilation support of any kind for custom kernels. Then, writing custom kernels would probably lead to lots of code duplication; with CUTLASS this may also be an issue, even if to a smaller extent. Etc., so it's all a matter of trade-offs. Still, having auto-tuning and auto-quantization in mind, I believe it may still be good to have as many different kernels in torchao as possible, so I'd expect more CUTLASS-based kernels to be written besides these W4A8 and W4A4 kernels; and this is the exact reason why, as discussed above, I'd prefer to have as much code as possible shared between these kernels.

@supriyar (Contributor) commented Jan 9, 2025

since W4A4 accuracy is probably quite bad.

Might be interesting to try out QAT with this setting cc @andrewor14

@alexsamardzic (Collaborator)

So I'd suggest we have your C++ code integrated in the way sketched by the attached diff, and then also make minor changes in the C++ code to bring it closer to the rowwise scaled MM implementation.

I've made these changes to the existing CUTLASS-based W4A8 kernel in #1545, so it should be easier now to eventually include the W4A4 functionality there.

torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h: three outdated review threads (resolved)
const auto is_sm8x = dprops->major >= 8;

if (is_sm8x) {
using ElementA = cutlass::int4b_t;
@alexsamardzic (Collaborator), Jan 22, 2025:

What is the reason for skipping the dispatch_on_tensor_a_and_tensor_b step here? I put that one there exactly to be able to easily choose the ElementA and ElementB template arguments.

If this file is changed to have dispatch_on_tensor_a_and_tensor_b, then s4s4_linear_cutlass.cu and s8s4_linear_cutlass.cu would be pretty much the same, except for the different ElementA/ElementB data-type pair and a different select_config. For this reason, my idea with the latest changes in s8s4_linear_cutlass.cu was actually to keep everything in a single file. There would be one operator visible to the Python side of torchao; instead of s8s4_linear_cutlass, we could name it rowwise_scaled_linear_cutlass, and it could handle S8/S4, S4/S4, and in the future pretty much any combination of same or mixed data types, both integer and floating point (and it could also be extended to handle SM90 and higher). The advantage is that there would be no duplicated code and it would be easy to extend with new data types; the disadvantage is that a large number of templates would be instantiated in a single file, so torchao build times would suffer badly (I have no immediate idea how, but I hope this could somehow be fixed, through some kind of explicit template instantiation). Please allow me some time, and I'll try to come up with a patch to show what I mean...

@gau-nernst (Collaborator, Author)

@alexsamardzic Thank you for your prompt feedback. I was going to work on this a little bit more before asking for another round of review.

I wasn't sure how much code sharing you intended to have. I was trying to work with a single file, but the biggest blocker is this one:

  if (tensor_a.scalar_type() == at::ScalarType::Char) {
    if (tensor_b.scalar_type() == at::ScalarType::Char) {
      if (tensor_a.size(1) == 2 * tensor_b.size(1)) {

In dispatch_on_tensor_a_and_tensor_b(), you rely on tensor_a.size(1) == 2 * tensor_b.size(1) to detect W4A8. That means there is no meaningful way to distinguish W4A4 from W8A8, unless there is an extra template argument, which kind of defeats the purpose of dispatch_on_tensor_a_and_tensor_b(), since we could then just call select_config() directly. I would appreciate your advice on this.
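
A toy illustration of the ambiguity (dummy tensors, shapes only):

import torch

# With no sub-byte dtypes in PyTorch, both cases below arrive as int8 storage
# with identical shapes, so a shape check alone cannot tell them apart.
M, N = 8, 16
a_w8a8 = torch.zeros(M, 64, dtype=torch.int8)   # int8 activations, K = 64
b_w8a8 = torch.zeros(N, 64, dtype=torch.int8)   # int8 weights, K = 64
a_w4a4 = torch.zeros(M, 64, dtype=torch.int8)   # packed int4 activations, K = 128
b_w4a4 = torch.zeros(N, 64, dtype=torch.int8)   # packed int4 weights, K = 128
assert a_w8a8.shape == a_w4a4.shape and b_w8a8.shape == b_w4a4.shape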

That's why I decided to move the inner-most template to a separate header file for sharing. I wanted to move select_config() too, but since it's about hard-coding kernel params for a particular GEMM, I felt it should stay in the GEMM-specific file.

I have another question. Currently you have dispatch_on_tensor_a_scale_and_tensor_b_scale() and dispatch_on_tensor_c(), which seems to suggest that you allow bias dtype != tensor scale dtype, though currently only the FP16 and BF16 permutations are supported. I don't know whether BF16 tensor scales + FP16 bias is useful/necessary (and vice versa for FP16 tensor scales and BF16 bias). Furthermore, torch._scaled_mm() uses FP32 scales and allows specifying the output dtype. I'm not exactly sure whether FP32 scales are necessary, but I'm just pointing out my observation here (we might want to support FP32 scales as well? Then perhaps it makes sense to support bias dtype != tensor scale dtype, where scale dtype = FP32 and bias dtype = FP16/BF16).
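
To make the scale/bias dtype question concrete, here is the reference math only (a sketch, not the operator's actual signature; allowing FP32 scales and a separate output dtype is just my assumption of what such support could look like):

import torch

def rowwise_scaled_linear_ref(a_int, a_scale, b_int, b_scale, bias=None,
                              out_dtype=torch.bfloat16):
    # Reference math: integer matmul accumulated in int32, rowwise scales
    # applied in fp32, optional bias added, result cast to the requested dtype.
    acc = a_int.to(torch.int32) @ b_int.to(torch.int32).T        # (M, N)
    out = acc.to(torch.float32) * a_scale.float().unsqueeze(1) * b_scale.float().unsqueeze(0)
    if bias is not None:
        out = out + bias.float()
    return out.to(out_dtype)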

check_inputs() will also require some way to tell that either A or B (or both) is int4, if you want to keep it in a single file.

Side comment: I don't mind code duplication. Realistically speaking, there are not that many useful combinations, so code duplication is not too bad. And, as you probably already know, for the CUTLASS device-level API not all combinations work (and when they don't work, there are mysterious errors 😅).

@alexsamardzic (Collaborator)

I realized, while making the changes, that your point on dispatching on the input/weight tensor data types is right. Namely, PyTorch doesn't provide sub-byte data types, so the only way to differentiate between data-type combinations like S4/S4 and S8/S8 is to have this information encoded somewhere else. Your approach was to encode it in the name of the operator, and this kind of approach could also solve the issue I was worried about, in that not all templates would get instantiated from a single source file.

I made some further changes to maximize the C++ code reuse. I don't mind code duplication either, but in this case it's (in my opinion) really cumbersome boilerplate code, which for sanity I'd much prefer to keep in a single place. So I came up with the following patch, to be applied on top of the current state of this branch:

Patch
diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
similarity index 73%
rename from benchmarks/benchmark_s8s4_cutlass.py
rename to benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
index fbf07eb..00bcb0a 100644
--- a/benchmarks/benchmark_s8s4_cutlass.py
+++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py
@@ -2,7 +2,7 @@ import pandas as pd
 import torch
 from tqdm import tqdm
 
-from torchao.ops import s8s4_linear_cutlass
+from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
 from torchao.utils import benchmark_torch_function_in_microseconds
 
 
@@ -24,8 +24,8 @@ def benchmark(m: int, k: int, n: int):
     A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)
 
     fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
-    s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
-        s8s4_linear_cutlass, A, A_scale, B, B_scale, C
+    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_torch_function_in_microseconds(
+        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
     )
 
     return {
@@ -33,8 +33,8 @@ def benchmark(m: int, k: int, n: int):
         "k": k,
         "n": n,
         "fp16_latency (ms)": fp16_time,
-        "s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time,
-        "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time,
+        "rowwise_scaled_linear_cutlass latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time,
+        "speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
     }
 
 
@@ -48,5 +48,5 @@ if __name__ == "__main__":
             results.append(benchmark(m, k, n))
 
     df = pd.DataFrame(results)
-    df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
+    df.to_csv("rowwise_scaled_linear_cutlass_s8s4_time_results.csv", index=False)
     print(df.to_markdown(index=False))
diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py
new file mode 100644
index 0000000..422b100
--- /dev/null
+++ b/test/test_rowwise_scaled_linear_cutlass.py
@@ -0,0 +1,128 @@
+import itertools
+
+import pytest
+import torch
+
+from torchao.ops import (
+    rowwise_scaled_linear_cutlass_s4s4,
+    rowwise_scaled_linear_cutlass_s8s4,
+)
+from torchao.quantization.utils import group_quantize_tensor_symmetric
+from torchao.utils import compute_max_diff
+
+ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
+ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
+ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
+    (2, 512, 128),
+    (3, 2048, 2048),
+    (4, 3584, 640),
+    (13, 8704, 8576),
+    (26, 18944, 1664),
+    (67, 6656, 1408),
+]
+ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
+ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
+    itertools.product(
+        ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
+        ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
+        ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
+        ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
+    )
+)
+
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
+    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
+)
+def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
+    size_m, size_n, size_k = size_mnk
+
+    input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+    weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+    input_2d = input.view(-1, input.shape[-1])
+    input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
+        input_2d, 4, size_k, dtype
+    )
+    assert torch.all(input_2d_zeros == 0)
+    input_s8 = input_2d_s8.reshape(input.shape)
+    input_s4 = ((input_s8[:, :, 1::2] & 0xF) << 4) | (input_s8[:, :, 0::2] & 0xF)
+    input_scales = input_2d_scales.reshape(input.shape[:-1])
+
+    weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
+        weight, 4, size_n, dtype
+    )
+    assert torch.all(weight_zeros == 0)
+    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
+
+    # If torch.nn.functional.linear(input, weight, bias) used as
+    # reference, the error would be too big.  The calculation below is
+    # approximately what rowwise_scaled_linear_cutlass kernel is doing
+    # (except that matrix multiplication is over integers there)).
+    size_m_2d = input_2d.shape[0]
+    output_ref = (
+        (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
+        * input_2d_scales.view(size_m_2d, 1)
+        * weight_scales.view(1, size_n)
+    )
+    if bias is not None:
+        output_ref += bias
+    output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
+
+    fn_inputs = (input_s4, input_scales, weight_s4, weight_scales, bias)
+    try:
+        output = rowwise_scaled_linear_cutlass_s4s4(*fn_inputs)
+    except NotImplementedError:
+        pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented")
+
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 1e-3
+
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
+    "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
+)
+def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
+    size_m, size_n, size_k = size_mnk
+
+    input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+    weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+    input_2d = input.view(-1, input.shape[-1])
+    input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
+        input_2d, 8, size_k, dtype
+    )
+    assert torch.all(input_2d_zeros == 0)
+    input_s8 = input_2d_s8.reshape(input.shape)
+    input_scales = input_2d_scales.reshape(input.shape[:-1])
+
+    weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
+        weight, 4, size_n, dtype
+    )
+    assert torch.all(weight_zeros == 0)
+    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
+
+    # If torch.nn.functional.linear(input, weight, bias) used as
+    # reference, the error would be too big.  The calculation below is
+    # approximately what rowwise_scaled_linear_cutlass kernel is doing
+    # (except that matrix multiplication is over integers there)).
+    size_m_2d = input_2d.shape[0]
+    output_ref = (
+        (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
+        * input_2d_scales.view(size_m_2d, 1)
+        * weight_scales.view(1, size_n)
+    )
+    if bias is not None:
+        output_ref += bias
+    output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
+
+    fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
+    try:
+        output = rowwise_scaled_linear_cutlass_s8s4(*fn_inputs)
+    except NotImplementedError:
+        pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented")
+
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 5e-3
diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py
deleted file mode 100644
index 6510ada..0000000
--- a/test/test_s8s4_linear_cutlass.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import itertools
-
-import pytest
-import torch
-
-from torchao.ops import s8s4_linear_cutlass
-from torchao.quantization.utils import group_quantize_tensor_symmetric
-from torchao.utils import compute_max_diff
-
-S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
-S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
-S8S4_LINEAR_CUTLASS_SIZE_MNK = [
-    (2, 512, 128),
-    (3, 2048, 2048),
-    (4, 3584, 640),
-    (13, 8704, 8576),
-    (26, 18944, 1664),
-    (67, 6656, 1408),
-]
-S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
-S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
-    itertools.product(
-        S8S4_LINEAR_CUTLASS_DTYPE,
-        S8S4_LINEAR_CUTLASS_BATCH_SIZE,
-        S8S4_LINEAR_CUTLASS_SIZE_MNK,
-        S8S4_LINEAR_CUTLASS_USE_BIAS,
-    )
-)
-
-
[email protected](not torch.cuda.is_available(), reason="CUDA not available")
[email protected](
-    "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
-)
-def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
-    size_m, size_n, size_k = size_mnk
-
-    input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
-    weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
-    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
-
-    input_2d = input.view(-1, input.shape[-1])
-    input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
-        input_2d, 8, size_k, dtype
-    )
-    assert torch.all(input_2d_zeros == 0)
-    input_s8 = input_2d_s8.reshape(input.shape)
-    input_scales = input_2d_scales.reshape(input.shape[:-1])
-
-    weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
-        weight, 4, size_n, dtype
-    )
-    assert torch.all(weight_zeros == 0)
-    weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)
-
-    # If torch.nn.functional.linear(input, weight, bias) used as
-    # reference, the error would be too big.  The calculation below is
-    # approximately what s8s4_linear_cutlass kernel is doing (except
-    # that matrrix multiplication is over integers there)).
-    size_m_2d = input_2d.shape[0]
-    output_ref = (
-        (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
-        * input_2d_scales.view(size_m_2d, 1)
-        * weight_scales.view(1, size_n)
-    )
-    if bias is not None:
-        output_ref += bias
-    output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))
-
-    fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
-    try:
-        output = s8s4_linear_cutlass(*fn_inputs)
-    except NotImplementedError:
-        pytest.xfail("s8s4_linear_cutlass() op not implemented")
-
-    max_diff = compute_max_diff(output, output_ref)
-    assert max_diff < 5e-3
diff --git a/torchao/csrc/cuda/int4_cutlass.cu b/torchao/csrc/cuda/int4_cutlass.cu
deleted file mode 100644
index 452abcc..0000000
--- a/torchao/csrc/cuda/int4_cutlass.cu
+++ /dev/null
@@ -1,231 +0,0 @@
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-
-// copied from s8s4_linear_cutlass.cu
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) &&                   \
-    defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_INT4_MM_CUTLASS
-#endif
-
-#if defined(BUILD_INT4_MM_CUTLASS)
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/device/gemm_universal.h"
-#include "cutlass/gemm/device/gemm.h"
-#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
-#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-
-#define CUTLASS_STATUS_CHECK(status)                                      \
-  {                                                                       \
-    TORCH_CHECK(status == cutlass::Status::kSuccess,                      \
-                __func__, " : Got CUTLASS error: ",                       \
-                cutlassGetStatusString(status));                          \
-  }
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_INT4_MM_CUTLASS)
-// define common params
-using ElementA           = cutlass::int4b_t;
-using ElementB           = cutlass::int4b_t;
-using ElementAccumulator = int32_t;
-using OpClass            = cutlass::arch::OpClassTensorOp;
-using ArchTag            = cutlass::arch::Sm80;
-
-// how many elements to load at a time -> load 128-bit = 32 x 4-bit
-constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
-constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
-#endif
-
-// we will do input checks in python. A and B are stored as int8
-torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) {
-#if defined(BUILD_INT4_MM_CUTLASS)
-  int M = A.size(0);
-  int K = A.size(1) * 2;
-  int N = B.size(1);
-  torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32));
-
-  // some configs for int4 mma
-  // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
-  // using default config. this can be tuned.
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
-  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 128>;
-  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
-  // static int const kStages = 3;
-  using ElementC = int32_t;
-  using Gemm = cutlass::gemm::device::Gemm<
-    ElementA, cutlass::layout::RowMajor,    // A matrix
-    ElementB, cutlass::layout::ColumnMajor, // B matrix
-    ElementC, cutlass::layout::RowMajor,    // C matrix
-    ElementAccumulator, OpClass, ArchTag,
-    ThreadblockShape, WarpShape, InstructionShape
-  >;
-  Gemm::Arguments args {
-    {M, N, K},
-    {reinterpret_cast<ElementA *>(A.data_ptr<int8_t>()), K},
-    {reinterpret_cast<ElementB *>(B.data_ptr<int8_t>()), K},
-    {C.data_ptr<ElementC>(), N},
-    {C.data_ptr<ElementC>(), N},
-    {1, 0}  // epilogue
-  };
-  Gemm gemm_op;
-  CUTLASS_STATUS_CHECK(gemm_op(args));
-  return C;
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
-template<
-  typename ElementC,
-  typename ThreadblockShape,
-  typename WarpShape,
-  typename InstructionShape,
-  int numStages>
-void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) {
-  // problem shape
-  int M = A.size(0);
-  int K = A.size(1) * 2;
-  int N = B.size(1);
-
-  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // 8 for BF16/FP16
-  using ElementEpilogue = float;
-  constexpr int numEpilogueStages = 1;
-
-  // build epilogue visitor tree
-  using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
-    ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages
-  >;
-
-  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
-  constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest;
-  using Multiply = cutlass::epilogue::threadblock::VisitorCompute<
-    cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode
-  >;
-
-  // (1, N)
-  using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
-    OutputTileThreadMap, ElementC,
-    cute::Stride<cute::_0, cute::_1, int32_t>  // MNL
-  >;
-  using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, Accum, ColScale>;
-
-  // (M, 1)
-  using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
-    OutputTileThreadMap, ElementC,
-    cute::Stride<cute::_1, cute::_0, int32_t>  // MNL
-  >;
-  using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<Multiply, EVTCompute0, RowScale>;
-
-  using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
-    OutputTileThreadMap, ElementC, RoundMode,
-    cute::Stride<int64_t, cute::_1, int64_t>  // MNL
-  >;
-  using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<Output, EVTCompute1>;
-
-  using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-    ElementA, cutlass::layout::RowMajor,    cutlass::ComplexTransform::kNone, AlignmentA,
-    ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB,
-    ElementC, cutlass::layout::RowMajor,                                      AlignmentC,
-    ElementAccumulator, ElementEpilogue, OpClass, ArchTag,
-    ThreadblockShape, WarpShape, InstructionShape,
-    EVTOutput,
-    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
-    numStages,
-    cutlass::arch::OpMultiplyAddSaturate,  // OpMultiplyAdd does not work
-    numEpilogueStages
-  >::GemmKernel;
-  using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
-
-  // col_scale, row_scale, and C must have the same dtype
-  const ElementA *A_ptr         = reinterpret_cast<ElementA *>(A.data_ptr<int8_t>());
-  const ElementB *B_ptr         = reinterpret_cast<ElementB *>(B.data_ptr<int8_t>());
-  const ElementC *col_scale_ptr = reinterpret_cast<ElementC *>(col_scale.data_ptr());
-  const ElementC *row_scale_ptr = reinterpret_cast<ElementC *>(row_scale.data_ptr());
-  ElementC *C_ptr               = reinterpret_cast<ElementC *>(C.data_ptr());
-
-  typename EVTOutput::Arguments callback_args{
-    {
-      {
-        {},                                                                  // Accum
-        {col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}},  // ColScale
-        {}                                                                   // Multiply
-      },                                                                     // EVTCompute0
-      {row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}},    // RowScale
-      {}                                                                     // Multiply
-    },                                                                       // EVTCompute1
-    {C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}}                          // EVTOutput
-  };
-
-  typename DeviceGemm::Arguments args(
-    cutlass::gemm::GemmUniversalMode::kGemm,
-    cutlass::gemm::GemmCoord{M, N, K},
-    1,                              // batch_split
-    callback_args,
-    A_ptr, B_ptr, nullptr, nullptr, // unsued C_ptr and D_ptr
-    M * K, N * K, 0, 0,             // batch_stride A, B, C, D
-    K, K, 0, 0                      // stride A, B, C, D
-  );
-
-  DeviceGemm gemm_op;
-  auto stream = at::cuda::getCurrentCUDAStream();
-  CUTLASS_STATUS_CHECK(gemm_op.can_implement(args));
-  CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream));
-}
-
-// we will do input checks in python. A and B are stored as int8
-// this function is based on the following cutlass example
-// https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu
-// also with the help of emitted code from cutlass Python
-torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) {
-#if defined(BUILD_INT4_MM_CUTLASS)
-  int M = A.size(0);
-  int N = B.size(1);
-  torch::Tensor C = torch::empty({M, N}, row_scale.options());
-
-  // some configs for int4 mma
-  // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
-  // using default config. this can be tuned.
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
-  using WarpShape        = cutlass::gemm::GemmShape<64, 64, 128>;
-  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
-  constexpr int numStages = 3;
-
-  AT_DISPATCH_SWITCH(
-    row_scale.scalar_type(),
-    "scaled_int4_mm_cutlass",
-    AT_DISPATCH_CASE(
-      torch::ScalarType::Half,
-      [&]() {
-        using ElementC = cutlass::half_t;
-        scaled_int4_mm_cutlass_dispatch<
-          ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
-            A, B, row_scale, col_scale, C);
-      }
-    )
-    AT_DISPATCH_CASE(
-      torch::ScalarType::BFloat16,
-      [&]() {
-        using ElementC = cutlass::bfloat16_t;
-        scaled_int4_mm_cutlass_dispatch<
-          ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>(
-            A, B, row_scale, col_scale, C);
-      }
-    )
-  );
-
-  return C;
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass);
-  m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass);
-}
-
-}  // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh
new file mode 100644
index 0000000..d1969bc
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh
@@ -0,0 +1,578 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <c10/util/Exception.h>
+
+#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) &&                   \
+    defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
+#define BUILD_ROWWISE_SCALED_LINEAR_CUTLASS
+#endif
+
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+#include <cuda_runtime.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm_universal.h>
+#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
+#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
+#include <cutlass/gemm/device/gemm_universal_adapter.h>
+
+#define CUTLASS_STATUS_CHECK(status)                                      \
+  {                                                                       \
+    TORCH_CHECK(status == cutlass::Status::kSuccess,                      \
+                __func__, " : Got CUTLASS error: ",                       \
+                cutlassGetStatusString(status));                          \
+  }
+#endif
+
+namespace torchao {
+
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+template<
+    typename ThreadblockShape,
+    typename WarpShape,
+    typename InstructionShape,
+    typename ThreadblockSwizzle,
+    int NumStages,
+    typename ElementA,
+    typename ElementB,
+    typename ElementOutput,
+    typename ElementC,
+    typename UseTensorC,
+    typename ElementAScale,
+    typename ElementBScale>
+void rowwise_scaled_linear_kernel_cutlass_sm8x(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  static_assert((cutlass::sizeof_bits<ElementA>::value >= 8 ||
+                 8 % cutlass::sizeof_bits<ElementA>::value == 0) &&
+                (cutlass::sizeof_bits<ElementB>::value >= 8 ||
+                 8 % cutlass::sizeof_bits<ElementB>::value == 0));
+
+  using SmArch = cutlass::arch::Sm80;
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutOutput = cutlass::layout::RowMajor;
+
+  using ElementAccumulator = int32_t;
+  using Operator =
+      std::conditional_t<std::is_same<ElementA, ElementB>::value,
+                         cutlass::arch::OpMultiplyAddSaturate,
+                         cutlass::arch::OpMultiplyAddMixedInputUpcast>;
+
+  using ElementEpilogue = float;
+
+  constexpr auto NumEVTEpilogueStages = 1;
+
+  const int m = tensor_a.size(0);
+  const int n = tensor_b.size(0);
+  int k = tensor_a.size(1);
+  if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
+    k *= 8 % cutlass::sizeof_bits<ElementA>::value;
+  }
+
+  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+  constexpr int AlignmentAScale =
+      128 / cutlass::sizeof_bits<ElementAScale>::value;
+  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+  constexpr int AlignmentBScale =
+      128 / cutlass::sizeof_bits<ElementBScale>::value;
+  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  constexpr int AlignmentOutput =
+      128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+  // Check for current CUTLASS limitations w.r.t. alignments.
+  TORCH_CHECK(k % AlignmentA == 0,
+              __func__, " : Number of columns of tensor A must be divisible ",
+              "by ", AlignmentA);
+  TORCH_CHECK(k % AlignmentB == 0,
+              __func__, " : Number of columns of tensor B must be divisible ",
+              "by ", AlignmentB);
+  TORCH_CHECK(n % AlignmentC == 0,
+              __func__, " : Number of columns of tensor C must be divisible ",
+              "by ", AlignmentC);
+
+  using TensorAScaleTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          ThreadblockShape,
+          WarpShape,
+          ElementAScale,
+          AlignmentAScale,
+          NumEVTEpilogueStages>;
+  using TensorBScaleTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          ThreadblockShape,
+          WarpShape,
+          ElementBScale,
+          AlignmentBScale,
+          NumEVTEpilogueStages>;
+  using TensorCTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          ThreadblockShape,
+          WarpShape,
+          ElementC,
+          AlignmentC,
+          NumEVTEpilogueStages>;
+  using OutputTileThreadMap =
+      cutlass::epilogue::threadblock::OutputTileThreadLayout<
+          ThreadblockShape,
+          WarpShape,
+          ElementOutput,
+          AlignmentOutput,
+          NumEVTEpilogueStages>;
+
+  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+  using TensorAScale =
+      cutlass::epilogue::threadblock::VisitorColBroadcast<
+          TensorAScaleTileThreadMap,
+          ElementAScale,
+          cute::Stride<cute::_1, cute::_0, int64_t>>;
+  using TensorAScaleArguments = typename TensorAScale::Arguments;
+
+  using TensorBScale =
+      cutlass::epilogue::threadblock::VisitorRowBroadcast<
+          TensorBScaleTileThreadMap,
+          ElementBScale,
+          cute::Stride<cute::_0, cute::_1, int64_t>>;
+  using TensorBScaleArguments = typename TensorBScale::Arguments;
+
+  using TensorCScalar =
+      cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
+  using TensorCTensor =
+      cutlass::epilogue::threadblock::VisitorRowBroadcast<
+          TensorCTileThreadMap,
+          ElementC,
+          cute::Stride<cute::_0, cute::_1, int64_t>>;
+  using TensorC =
+      std::conditional_t<UseTensorC::value, TensorCTensor, TensorCScalar>;
+  using TensorCArguments = typename TensorC::Arguments;
+
+  using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplyAScale,
+      Accum,
+      TensorAScale>;
+
+  using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplyBScale,
+      EVTApplyAScale,
+      TensorBScale>;
+
+  using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::plus, ElementEpilogue, ElementEpilogue,
+      cutlass::FloatRoundStyle::round_to_nearest
+  >;
+  using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
+      ApplySum,
+      EVTApplyBScale,
+      TensorC>;
+
+  using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
+      OutputTileThreadMap, ElementOutput,
+      cutlass::FloatRoundStyle::round_to_nearest,
+      cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
+  >;
+
+  using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
+      Output,
+      EVTApplySum>;
+
+  using EVTKernel =
+      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+      ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
+      ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
+      ElementOutput, LayoutOutput, AlignmentOutput,
+      ElementAccumulator,
+      ElementEpilogue,
+      cutlass::arch::OpClassTensorOp,
+      SmArch,
+      ThreadblockShape,
+      WarpShape,
+      InstructionShape,
+      EVTOutput,
+      ThreadblockSwizzle,
+      NumStages,
+      Operator,
+      NumEVTEpilogueStages
+  >::GemmKernel;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
+
+  cutlass::gemm::GemmCoord problem_size(m, n, k);
+  constexpr auto SplitKFactor = 1;
+
+  TensorAScaleArguments tensor_a_scale_arguments{
+    (ElementAScale*)tensor_a_scale.data_ptr(),
+    ElementAScale(1),
+    {cute::_1{}, cute::_0{}, problem_size.m()}
+  };
+  TensorBScaleArguments tensor_b_scale_arguments{
+    (ElementBScale*)tensor_b_scale.data_ptr(),
+    ElementBScale(1),
+    {cute::_0{}, cute::_1{}, problem_size.n()}
+  };
+  TensorCArguments tensor_c_arguments{
+    [&]() -> TensorCArguments {
+      if constexpr (UseTensorC::value) {
+        return {(ElementC*)tensor_c.data_ptr(),
+                ElementC(0),
+                {cute::_0{}, cute::_1{}, problem_size.n()}};
+      } else {
+        return {ElementC(0)};
+      }
+    }()
+  };
+  typename Output::Arguments output_arguments{
+    (ElementOutput*)tensor_d.data_ptr(),
+    {problem_size.n(), cute::_1{}, problem_size.mn().product()}
+  };
+  typename EVTOutput::Arguments callback_arguments{
+    {
+      {
+        {
+          {},                        // Accum
+          tensor_a_scale_arguments,  // TensorAScale
+          {}                         // ApplyAScale
+        },                           // EVTApplyAScale
+        tensor_b_scale_arguments,    // TensorBScale
+        {},                          // ApplyBScale
+      },                             // EVTApplyBScale
+      tensor_c_arguments,            // TensorC
+      {}                             // ApplySum
+    },                               // EVTApplySum
+    output_arguments                 // Output
+  };                                 // EVTOutput
+
+  typename Gemm::Arguments arguments(
+    cutlass::gemm::GemmUniversalMode::kGemm,
+    problem_size,
+    SplitKFactor,
+    callback_arguments,              // arguments of EVT callbacks
+    (ElementA*)tensor_a.data_ptr(),
+    (ElementB*)tensor_b.data_ptr(),
+    nullptr,                         // ptr C (unused)
+    nullptr,                         // ptr D (unused)
+    problem_size.mk().product(),     // batch stride A
+    problem_size.nk().product(),     // batch stride B
+    0,                               // batch stride C (unused)
+    0,                               // batch stride D (unused)
+    problem_size.k(),                // stride A
+    problem_size.k(),                // stride B
+    0,                               // stride C (unused)
+    0                                // stride D (unused)
+  );
+
+  Gemm gemm_op;
+
+  cutlass::Status status;
+
+  // Verify that GEMM operation with given arguments can be performed
+  // by CUTLASS.
+  status = gemm_op.can_implement(arguments);
+  CUTLASS_STATUS_CHECK(status);
+
+  // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
+  const auto workspace_size = Gemm::get_workspace_size(arguments);
+  auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
+                                      at::TensorOptions().dtype(at::kByte));
+
+  // Initialize CUTLASS mixed datatypes GEMM object.
+  status = gemm_op.initialize(arguments, workspace.data_ptr(),
+                              at::cuda::getCurrentCUDAStream());
+  CUTLASS_STATUS_CHECK(status);
+
+  // Perform mixed datatypes GEMM operation.
+  status = gemm_op.run(at::cuda::getCurrentCUDAStream());
+  CUTLASS_STATUS_CHECK(status);
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename ElementA, typename ElementB, typename... Types>
+static void select_config(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) { 
+  const auto dprops = at::cuda::getCurrentDeviceProperties();
+  const auto is_sm8x = dprops->major == 8;
+
+  if (is_sm8x) {
+    if constexpr (std::is_same<ElementA, cutlass::int4b_t>::value &&
+                  std::is_same<ElementB, cutlass::int4b_t>::value) {
+      using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
+      using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
+      using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
+      using ThreadblockSwizzle =
+          cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
+      constexpr auto NumStages = 3;
+      rowwise_scaled_linear_kernel_cutlass_sm8x<
+        ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+          NumStages, ElementA, ElementB, Types...>(
+            tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+            tensor_d);
+      return;
+    } else if constexpr (std::is_same<ElementA, int8_t>::value &&
+                  std::is_same<ElementB, cutlass::int4b_t>::value) {
+      using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+      using ThreadblockSwizzle =
+        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
+
+      // A minimal heuristic to improve performance for small number
+      // of inputs cases.
+      if (tensor_a.size(0) <= 16) {
+        using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
+        constexpr auto NumStages = 6;
+        rowwise_scaled_linear_kernel_cutlass_sm8x<
+          ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+          NumStages, ElementA, ElementB, Types...>(
+              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+              tensor_d);
+      } else if (tensor_a.size(0) <= 32) {
+        using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
+        constexpr auto NumStages = 5;
+        rowwise_scaled_linear_kernel_cutlass_sm8x<
+          ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+          NumStages, ElementA, ElementB, Types...>(
+              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+              tensor_d);
+      } else {
+        using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
+        using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
+        constexpr auto NumStages = 4;
+        rowwise_scaled_linear_kernel_cutlass_sm8x<
+          ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle,
+          NumStages, ElementA, ElementB, Types...>(
+              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+              tensor_d);
+      }
+      return;
+    }
+  }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported on SM", dprops->major, ".",
+              dprops->minor, " for given operands");
+}
+
+template<
+   typename ElementA,
+   typename ElementB,
+   typename ElementOutput,
+   typename... Types>
+static void
+dispatch_on_tensor_c(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  if (tensor_c.numel() == 0) {
+    using ElementC = ElementOutput;
+    using UseTensorC = std::false_type;
+    select_config<
+      ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+          tensor_d);
+    return;
+  }
+
+  using UseTensorC = std::true_type;
+  if (tensor_c.scalar_type() == at::ScalarType::Half) {
+    using ElementC = cutlass::half_t;
+    select_config<
+      ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+          tensor_d);
+    return;
+  } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
+    using ElementC = cutlass::bfloat16_t;
+    select_config<
+      ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>(
+          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
+          tensor_d);
+    return;
+  }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported for datatype ",
+                tensor_c.scalar_type(), " for addend");
+}
+
+template<typename ElementA, typename ElementB>
+static void
+dispatch_on_tensor_a_scale_and_tensor_b_scale(
+    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
+    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
+    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
+  TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
+              __func__, " : Operator not supported for output datatype ",
+              tensor_d.scalar_type(), " as it's different from the first ",
+              " operand scale datatype ", tensor_a_scale.scalar_type());
+
+  if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
+      tensor_b_scale.scalar_type() == at::ScalarType::Half) {
+    using ElementAScale = cutlass::half_t;
+    using ElementBScale = cutlass::half_t;
+    using ElementOutput = cutlass::half_t;
+    dispatch_on_tensor_c<ElementA, ElementB, ElementOutput, ElementAScale,
+      ElementBScale>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+    return;
+  } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
+             tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
+    using ElementAScale = cutlass::bfloat16_t;
+    using ElementBScale = cutlass::bfloat16_t;
+    using ElementOutput = cutlass::bfloat16_t;
+    dispatch_on_tensor_c<ElementA, ElementB, ElementOutput, ElementAScale,
+      ElementBScale>(
+        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
+    return;
+  }
+
+  TORCH_CHECK(false,
+              __func__, " : Operator not supported for combination of data ",
+              "types ", tensor_a_scale.scalar_type(),
+              " for first operand scale and ", tensor_b_scale.scalar_type(),
+                " for second operand scale");
+}
+
+template<typename ElementA, typename ElementB>
+void
+rowwise_scaled_linear_cutlass_check_inputs(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+  // Validate layouts of arguments.
+  TORCH_CHECK(xq.dim() >= 2,
+              __func__, " : Expected xq argument to be 2D or "
+              "higher-dimensional tensor, got ", xq.dim(), " dims");
+  TORCH_CHECK(xq.layout() == at::Layout::Strided,
+              __func__, " : Expected xq argument to be strided, got layout ",
+              xq.layout());
+  TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
+              __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
+              "D tensor, got ", x_scale.dim(), " dims");
+  TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
+              __func__, " : Expected xq scale argument to be strided, got "
+              "layout ", x_scale.layout());
+  TORCH_CHECK(wq.dim() == 2,
+              __func__, " : Expected wq argument to be 2D tensor, got ",
+              wq.dim(), " dims");
+  TORCH_CHECK(wq.layout() == at::Layout::Strided,
+              __func__, " : Expected wq argument to be strided, got layout ",
+              wq.layout());
+  TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
+              __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
+              "got ", w_scale.dim(), " dims");
+  TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
+              __func__, " : Expected wq scale argument to be strided, got "
+              "layout ", w_scale.layout());
+  if (bias.numel() > 0) {
+    TORCH_CHECK(bias.dim() == 1,
+                __func__, " : Expected bias argument to be 1D tensor, got ",
+                bias.dim(), " dims");
+    TORCH_CHECK(bias.layout() == at::Layout::Strided,
+                __func__, " : Expected bias argument to be strided, got ",
+                "layout ", bias.layout());
+  }
+
+  // Validate sizes of arguments.
+  const auto xq_sizes = xq.sizes().vec();
+  TORCH_CHECK(xq_sizes.back() == wq.size(1) ||
+              xq_sizes.back() == 2 * wq.size(1),
+              __func__, " : Expected xq argument to have ", wq.size(1), " or ",
+              2 * wq.size(1), " columns, but got ", xq_sizes.back());
+  const auto x_scale_sizes = x_scale.sizes().vec();
+  for (auto i = 0; i < x_scale_sizes.size(); ++i)
+    TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
+                __func__, " : Expected xq scale argument size at position ",
+                i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
+  TORCH_CHECK(w_scale.numel() == wq.size(0),
+              __func__, " : Expected wq scale argument to have ", wq.size(0),
+              " elements, got ", w_scale.numel(), " elements");
+  if (bias.numel() > 0) {
+    TORCH_CHECK(bias.numel() == wq.size(0),
+                __func__, " : Expected bias argument to have ", wq.size(0),
+                " elements, got ", bias.numel(), " elements");
+  }
+
+  // Validate strides of arguments.
+  const auto xq_strides = xq.strides();
+  TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
+              __func__, " : Expected xq argument in row-major layout");
+  auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
+  for (int i = xq_strides.size() - 3; i >= 0; --i) {
+    xq_stride_expected *= xq_sizes[i + 1];
+    TORCH_CHECK(xq_strides[i] == xq_stride_expected,
+                __func__, " : Expected xq argument in row-major layout");
+  }
+  TORCH_CHECK(x_scale.is_contiguous(),
+              __func__, " : Expected xq scale argument to be contiguous");
+  const auto wq_strides = wq.strides();
+  TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
+              __func__, " : Expected wq argument in row-major layout");
+  TORCH_CHECK(w_scale.is_contiguous(),
+              __func__, " : Expected wq scale argument to be contiguous");
+  if (bias.numel() > 0) {
+    const auto bias_strides = bias.strides();
+    TORCH_CHECK(bias_strides[0] == 1,
+                __func__, " : Expected bias argument to be contiguous");
+  }
+}
+#endif
+
+// Perform linear operation, using corresponding CUTLASS datatypes
+// GEMM kernel, to given arguments - result produced is:
+// (tensor_a * tensor_a_scale) @ (tensor_b * tensor_b_scale).T + tensor_c
+//
+// Notes: The "tensor_a" and "tensor_b" are expected to be 2D tensors.
+// The "tensor_a_scale" tensor is expected to be a vector, of size
+// equal to number of rows of "tensor_a" tensor.  The "tensor_b_scale"
+// tensor is expected to be a vector, of size equal to number of rows
+// of "tensor_b" tensor. The "tensor_c" tensor is expected to be a
+// vector, of size equal to number of rows of "tensor_b" tensor.
+template <typename ElementA, typename ElementB>
+at::Tensor
+rowwise_scaled_linear_cutlass(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS)
+  // Check inputs.
+  rowwise_scaled_linear_cutlass_check_inputs<ElementA, ElementB>(
+      xq, x_scale, wq, w_scale, bias);
+
+  // Squash the input tensors as appropriate.
+  const auto xq_sizes = xq.sizes().vec();
+  const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
+  const auto x_scale_1d = x_scale.reshape({-1});
+  const auto w_scale_1d = w_scale.reshape({-1});
+
+  // Create result tensor.
+  at::Tensor result =
+      x_scale.new_empty({xq_2d.size(0), wq.size(0)});
+
+  // Dispatch to appropriate kernel template.
+  dispatch_on_tensor_a_scale_and_tensor_b_scale<ElementA, ElementB>(
+      xq_2d, x_scale_1d, wq, w_scale_1d, bias, result);
+
+  // Reshape and return result tensor.
+  auto result_sizes = xq_sizes;
+  result_sizes.back() = wq.size(0);
+  return result.reshape(result_sizes);
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
+}  // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
new file mode 100644
index 0000000..9a64b2b
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu
@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#include "rowwise_scaled_linear_cutlass.cuh"
+
+namespace torchao {
+
+at::Tensor
+rowwise_scaled_linear_cutlass_s4s4(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+  // Validate input datatypes.
+  TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar,
+              __func__, " : The input datatypes combination ", xq.dtype(),
+              " for xq and ", wq.dtype(), " for wq is not supported");
+
+  // Dispatch to appropriate kernel template.
+  using ElementA = cutlass::int4b_t;
+  using ElementB = cutlass::int4b_t;
+  return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+      xq, x_scale, wq, w_scale, bias);
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::rowwise_scaled_linear_cutlass_s4s4",
+         &rowwise_scaled_linear_cutlass_s4s4);
+}
+
+}  // namespace torchao
diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
new file mode 100644
index 0000000..752c557
--- /dev/null
+++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu
@@ -0,0 +1,28 @@
+#include <torch/extension.h>
+
+#include "rowwise_scaled_linear_cutlass.cuh"
+
+namespace torchao {
+
+at::Tensor
+rowwise_scaled_linear_cutlass_s8s4(
+    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
+    const at::Tensor& w_scale, const at::Tensor& bias) {
+  // Validate input datatypes.
+  TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar,
+              __func__, " : The input datatypes combination ", xq.dtype(),
+              " for xq and ", wq.dtype(), " for wq is not supported");
+
+  // Dispatch to appropriate kernel template.
+  using ElementA = int8_t;
+  using ElementB = cutlass::int4b_t;
+  return rowwise_scaled_linear_cutlass<ElementA, ElementB>(
+      xq, x_scale, wq, w_scale, bias);
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::rowwise_scaled_linear_cutlass_s8s4",
+         &rowwise_scaled_linear_cutlass_s8s4);
+}
+
+}  // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu
deleted file mode 100644
index 8faf13c..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu
+++ /dev/null
@@ -1,268 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) &&                   \
-    defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_S4S4_LINEAR_CUTLASS
-#endif
-
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-#include "scaled_linear.h"
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-
-template<typename... Types>
-static void select_config(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  const auto dprops = at::cuda::getCurrentDeviceProperties();
-  const auto is_sm8x = dprops->major >= 8;
-
-  if (is_sm8x) {
-    using ElementA = cutlass::int4b_t;
-    using ElementB = cutlass::int4b_t;
-    using ElementAccumulator = int32_t;
-
-    using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>;
-    using WarpShape        = cutlass::gemm::GemmShape<64, 64, 128>;
-    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
-    constexpr auto NumStages = 3;
-    using Operator = cutlass::arch::OpMultiplyAddSaturate;
-    // using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast;  // this does not work
-    using ThreadblockSwizzle =
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
-
-    scaled_linear_kernel_cutlass_sm8x<
-      ThreadblockShape, WarpShape, InstructionShape, NumStages,
-      ThreadblockSwizzle, ElementA, ElementB, ElementAccumulator, Operator,
-      Types...>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported on SM", dprops->major, ".",
-              dprops->minor, " for given operands");
-}
-
-template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
-static void
-dispatch_on_tensor_c(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  if (tensor_c.numel() == 0) {
-    using ElementC = ElementOutput;
-    using UseTensorC = std::false_type;
-    select_config<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  using UseTensorC = std::true_type;
-  if (tensor_c.scalar_type() == at::ScalarType::Half) {
-    using ElementC = cutlass::half_t;
-    select_config<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
-    using ElementC = cutlass::bfloat16_t;
-    select_config<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for datatype ",
-                tensor_c.scalar_type(), " for addend");
-}
-
-static void
-dispatch_on_tensor_a_scale_and_tensor_b_scale(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
-              __func__, " : Operator not supported for output datatype ",
-              tensor_d.scalar_type(), " as it's different from the first ",
-              " operand scale datatype ", tensor_a_scale.scalar_type());
-
-  if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
-      tensor_b_scale.scalar_type() == at::ScalarType::Half) {
-    using ElementAScale = cutlass::half_t;
-    using ElementBScale = cutlass::half_t;
-    using ElementOutput = cutlass::half_t;
-    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-    return;
-  } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
-             tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
-    using ElementAScale = cutlass::bfloat16_t;
-    using ElementBScale = cutlass::bfloat16_t;
-    using ElementOutput = cutlass::bfloat16_t;
-    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for combination of data ",
-              "types ", tensor_a_scale.scalar_type(),
-              " for first operand scale and ", tensor_b_scale.scalar_type(),
-                " for second operand scale");
-}
-
-static void
-check_inputs(
-    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
-    const at::Tensor& w_scale, const at::Tensor& bias) {
-  // Validate layouts of arguments.
-  TORCH_CHECK(xq.dim() >= 2,
-              __func__, " : Expected xq argument to be 2D or "
-              "higher-dimensional tensor, got ", xq.dim(), " dims");
-  TORCH_CHECK(xq.layout() == at::Layout::Strided,
-              __func__, " : Expected xq argument to be strided, got layout ",
-              xq.layout());
-  TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
-              __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
-              "D tensor, got ", x_scale.dim(), " dims");
-  TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected xq scale argument to be strided, got "
-              "layout ", x_scale.layout());
-  TORCH_CHECK(wq.dim() == 2,
-              __func__, " : Expected wq argument to be 2D tensor, got ",
-              wq.dim(), " dims");
-  TORCH_CHECK(wq.layout() == at::Layout::Strided,
-              __func__, " : Expected wq argument to be strided, got layout ",
-              wq.layout());
-  TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
-              __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
-              "got ", w_scale.dim(), " dims");
-  TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected wq scale argument to be strided, got "
-              "layout ", w_scale.layout());
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.dim() == 1,
-                __func__, " : Expected bias argument to be 1D tensor, got ",
-                bias.dim(), " dims");
-    TORCH_CHECK(bias.layout() == at::Layout::Strided,
-                __func__, " : Expected bias argument to be strided, got ",
-                "layout ", bias.layout());
-  }
-
-  // Validate sizes of arguments.
-  const auto xq_sizes = xq.sizes().vec();
-  TORCH_CHECK(xq_sizes.back() == wq.size(1),
-              __func__, " : Expected xq argument to have ", wq.size(1),
-              " columns, but got ", xq_sizes.back());
-  const auto x_scale_sizes = x_scale.sizes().vec();
-  for (auto i = 0; i < x_scale_sizes.size(); ++i)
-    TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
-                __func__, " : Expected xq scale argument size at position ",
-                i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
-  TORCH_CHECK(w_scale.numel() == wq.size(0),
-              __func__, " : Expected wq scale argument to have ", wq.size(0),
-              " elements, got ", w_scale.numel(), " elements");
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.numel() == wq.size(0),
-                __func__, " : Expected bias argument to have ", wq.size(0),
-                " elements, got ", bias.numel(), " elements");
-  }
-
-  // Validate strides of arguments.
-  const auto xq_strides = xq.strides();
-  TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
-              __func__, " : Expected xq argument in row-major layout");
-  auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
-  for (int i = xq_strides.size() - 3; i >= 0; --i) {
-    xq_stride_expected *= xq_sizes[i + 1];
-    TORCH_CHECK(xq_strides[i] == xq_stride_expected,
-                __func__, " : Expected xq argument in row-major layout");
-  }
-  TORCH_CHECK(x_scale.is_contiguous(),
-              __func__, " : Expected xq scale argument to be contiguous");
-  const auto wq_strides = wq.strides();
-  TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
-              __func__, " : Expected wq argument in row-major layout");
-  TORCH_CHECK(w_scale.is_contiguous(),
-              __func__, " : Expected wq scale argument to be contiguous");
-  if (bias.numel() > 0) {
-    const auto bias_strides = bias.strides();
-    TORCH_CHECK(bias_strides[0] == 1,
-                __func__, " : Expected bias argument to be contiguous");
-  }
-}
-#endif
-
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-//   result = (xq * x_scale) @ (wq * w_scale).T + bias
-// Notes: The "x_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "xq" tensor.  The "w_scale" tensor is
-// expected to be a vector, of size equal to number of rows of "wq"
-// tensor. The "bias" tensor is expected to be a vector, of size equal
-// to number of rows of "wq" tensor.
-at::Tensor
-s4s4_linear_cutlass(
-    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
-    const at::Tensor& w_scale, const at::Tensor& bias) {
-#if defined(BUILD_S4S4_LINEAR_CUTLASS)
-  // Check inputs.
-  check_inputs(xq, x_scale, wq, w_scale, bias);
-
-  // Squash the input tensors as appropriate.
-  const auto xq_sizes = xq.sizes().vec();
-  const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
-  const auto x_scale_sizes = x_scale.sizes().vec();
-  const auto x_scale_1d = x_scale.reshape({-1});
-  const auto w_scale_1d = w_scale.reshape({-1});
-
-  // Introduce alias names for arguments, according to the CUTLASS
-  // naming conventions.
-  const auto& tensor_a = xq_2d;
-  const auto& tensor_a_scale = x_scale_1d;
-  const auto& tensor_b = wq;
-  const auto& tensor_b_scale = w_scale_1d;
-  const auto& tensor_c = bias;
-
-  // Create output tensor.
-  at::Tensor tensor_d =
-      tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
-
-  // Dispatch to appropriate kernel template.
-  dispatch_on_tensor_a_scale_and_tensor_b_scale(
-      tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-
-  // Reshape and return output tensor.
-  auto tensor_d_sizes = xq_sizes;
-  tensor_d_sizes.back() = wq.size(0);
-  return tensor_d.reshape(tensor_d_sizes);
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::s4s4_linear_cutlass", &s4s4_linear_cutlass);
-}
-
-}  // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
deleted file mode 100644
index 53eaf53..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
+++ /dev/null
@@ -1,315 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) &&                   \
-    defined(CUDA_VERSION) && (CUDA_VERSION >= 11080)
-#define BUILD_S8S4_LINEAR_CUTLASS
-#endif
-
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-#include "scaled_linear.h"
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#endif
-
-namespace torchao {
-
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-
-template<typename ElementA, typename ElementB, typename... Types>
-static void select_config(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  const auto dprops = at::cuda::getCurrentDeviceProperties();
-  const auto is_sm8x = dprops->major == 8;
-
-  if (is_sm8x) {
-    if constexpr (std::is_same<ElementA, int8_t>::value &&
-                  std::is_same<ElementB, cutlass::int4b_t>::value) {
-      using ThreadblockSwizzle =
-        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
-      using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-
-      // A minimal heuristic to improve performance for small number
-      // of inputs cases.
-      if (tensor_a.size(0) <= 16) {
-        using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>;
-        using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>;
-        constexpr auto NumStages = 6;
-        scaled_linear_kernel_cutlass_sm8x<
-          ThreadblockShape, WarpShape, InstructionShape, NumStages,
-          ThreadblockSwizzle, ElementA, ElementB, Types...>(
-              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-              tensor_d);
-      } else if (tensor_a.size(0) <= 32) {
-        using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
-        using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
-        constexpr auto NumStages = 5;
-        scaled_linear_kernel_cutlass_sm8x<
-          ThreadblockShape, WarpShape, InstructionShape, NumStages,
-          ThreadblockSwizzle, ElementA, ElementB, Types...>(
-              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-              tensor_d);
-      } else {
-        using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>;
-        using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>;
-        constexpr auto NumStages = 4;
-        scaled_linear_kernel_cutlass_sm8x<
-          ThreadblockShape, WarpShape, InstructionShape, NumStages,
-          ThreadblockSwizzle, ElementA, ElementB, Types...>(
-              tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-              tensor_d);
-      }
-      return;
-    }
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported on SM", dprops->major, ".",
-              dprops->minor, " for given operands");
-}
-
-template<typename... Types>
-static void
-dispatch_on_tensor_a_and_tensor_b(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  if (tensor_a.scalar_type() == at::ScalarType::Char) {
-    if (tensor_b.scalar_type() == at::ScalarType::Char) {
-      if (tensor_a.size(1) == 2 * tensor_b.size(1)) {
-        using ElementA = int8_t;
-        using ElementB = cutlass::int4b_t;
-        using ElementAccumulator = int32_t;
-        using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast;
-        select_config<
-          ElementA, ElementB, ElementAccumulator, Operator, Types...>(
-            tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-            tensor_d);
-      }
-      return;
-    }
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for combination of data ",
-              "types ", tensor_a.scalar_type(), " for first operand and ",
-              tensor_b.scalar_type(), " for second operand");
-}
-
-
-template<typename ElementAScale, typename ElementBScale, typename ElementOutput>
-static void
-dispatch_on_tensor_c(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  if (tensor_c.numel() == 0) {
-    using ElementC = ElementOutput;
-    using UseTensorC = std::false_type;
-    dispatch_on_tensor_a_and_tensor_b<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  using UseTensorC = std::true_type;
-  if (tensor_c.scalar_type() == at::ScalarType::Half) {
-    using ElementC = cutlass::half_t;
-    dispatch_on_tensor_a_and_tensor_b<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) {
-    using ElementC = cutlass::bfloat16_t;
-    dispatch_on_tensor_a_and_tensor_b<
-      ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>(
-          tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c,
-          tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for datatype ",
-                tensor_c.scalar_type(), " for addend");
-}
-
-static void
-dispatch_on_tensor_a_scale_and_tensor_b_scale(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(),
-              __func__, " : Operator not supported for output datatype ",
-              tensor_d.scalar_type(), " as it's different from the first ",
-              " operand scale datatype ", tensor_a_scale.scalar_type());
-
-  if (tensor_a_scale.scalar_type() == at::ScalarType::Half &&
-      tensor_b_scale.scalar_type() == at::ScalarType::Half) {
-    using ElementAScale = cutlass::half_t;
-    using ElementBScale = cutlass::half_t;
-    using ElementOutput = cutlass::half_t;
-    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-    return;
-  } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 &&
-             tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) {
-    using ElementAScale = cutlass::bfloat16_t;
-    using ElementBScale = cutlass::bfloat16_t;
-    using ElementOutput = cutlass::bfloat16_t;
-    dispatch_on_tensor_c<ElementAScale, ElementBScale, ElementOutput>(
-        tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-    return;
-  }
-
-  TORCH_CHECK(false,
-              __func__, " : Operator not supported for combination of data ",
-              "types ", tensor_a_scale.scalar_type(),
-              " for first operand scale and ", tensor_b_scale.scalar_type(),
-                " for second operand scale");
-}
-
-static void
-check_inputs(
-    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
-    const at::Tensor& w_scale, const at::Tensor& bias) {
-  // Validate layouts of arguments.
-  TORCH_CHECK(xq.dim() >= 2,
-              __func__, " : Expected xq argument to be 2D or "
-              "higher-dimensional tensor, got ", xq.dim(), " dims");
-  TORCH_CHECK(xq.layout() == at::Layout::Strided,
-              __func__, " : Expected xq argument to be strided, got layout ",
-              xq.layout());
-  TORCH_CHECK(x_scale.dim() == xq.dim() - 1,
-              __func__, " : Expected xq scale argument to be ", xq.dim() - 1,
-              "D tensor, got ", x_scale.dim(), " dims");
-  TORCH_CHECK(x_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected xq scale argument to be strided, got "
-              "layout ", x_scale.layout());
-  TORCH_CHECK(wq.dim() == 2,
-              __func__, " : Expected wq argument to be 2D tensor, got ",
-              wq.dim(), " dims");
-  TORCH_CHECK(wq.layout() == at::Layout::Strided,
-              __func__, " : Expected wq argument to be strided, got layout ",
-              wq.layout());
-  TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2,
-              __func__, " : Expected wq scale argument to be 1D or 2D tensor, ",
-              "got ", w_scale.dim(), " dims");
-  TORCH_CHECK(w_scale.layout() == at::Layout::Strided,
-              __func__, " : Expected wq scale argument to be strided, got "
-              "layout ", w_scale.layout());
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.dim() == 1,
-                __func__, " : Expected bias argument to be 1D tensor, got ",
-                bias.dim(), " dims");
-    TORCH_CHECK(bias.layout() == at::Layout::Strided,
-                __func__, " : Expected bias argument to be strided, got ",
-                "layout ", bias.layout());
-  }
-
-  // Validate sizes of arguments.
-  const auto xq_sizes = xq.sizes().vec();
-  TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1),
-              __func__, " : Expected xq argument to have ", 2 * wq.size(1),
-              " columns, but got ", xq_sizes.back());
-  const auto x_scale_sizes = x_scale.sizes().vec();
-  for (auto i = 0; i < x_scale_sizes.size(); ++i)
-    TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i],
-                __func__, " : Expected xq scale argument size at position ",
-                i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]);
-  TORCH_CHECK(w_scale.numel() == wq.size(0),
-              __func__, " : Expected wq scale argument to have ", wq.size(0),
-              " elements, got ", w_scale.numel(), " elements");
-  if (bias.numel() > 0) {
-    TORCH_CHECK(bias.numel() == wq.size(0),
-                __func__, " : Expected bias argument to have ", wq.size(0),
-                " elements, got ", bias.numel(), " elements");
-  }
-
-  // Validate strides of arguments.
-  const auto xq_strides = xq.strides();
-  TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1,
-              __func__, " : Expected xq argument in row-major layout");
-  auto xq_stride_expected = xq_strides[xq_strides.size() - 2];
-  for (int i = xq_strides.size() - 3; i >= 0; --i) {
-    xq_stride_expected *= xq_sizes[i + 1];
-    TORCH_CHECK(xq_strides[i] == xq_stride_expected,
-                __func__, " : Expected xq argument in row-major layout");
-  }
-  TORCH_CHECK(x_scale.is_contiguous(),
-              __func__, " : Expected xq scale argument to be contiguous");
-  const auto wq_strides = wq.strides();
-  TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1,
-              __func__, " : Expected wq argument in row-major layout");
-  TORCH_CHECK(w_scale.is_contiguous(),
-              __func__, " : Expected wq scale argument to be contiguous");
-  if (bias.numel() > 0) {
-    const auto bias_strides = bias.strides();
-    TORCH_CHECK(bias_strides[0] == 1,
-                __func__, " : Expected bias argument to be contiguous");
-  }
-}
-#endif
-
-// Perform linear operation, using corresponding CUTLASS mixed
-// data-types GEMM kernel, to given arguments:
-//   result = (xq * x_scale) @ (wq * w_scale).T + bias
-// Notes: The "x_scale" tensor is expected to be a vector, of size
-// equal to number of rows of "xq" tensor.  The "w_scale" tensor is
-// expected to be a vector, of size equal to number of rows of "wq"
-// tensor. The "bias" tensor is expected to be a vector, of size equal
-// to number of rows of "wq" tensor.
-at::Tensor
-s8s4_linear_cutlass(
-    const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq,
-    const at::Tensor& w_scale, const at::Tensor& bias) {
-#if defined(BUILD_S8S4_LINEAR_CUTLASS)
-  // Check inputs.
-  check_inputs(xq, x_scale, wq, w_scale, bias);
-
-  // Squash the input tensors as appropriate.
-  const auto xq_sizes = xq.sizes().vec();
-  const auto xq_2d = xq.reshape({-1, xq_sizes.back()});
-  const auto x_scale_sizes = x_scale.sizes().vec();
-  const auto x_scale_1d = x_scale.reshape({-1});
-  const auto w_scale_1d = w_scale.reshape({-1});
-
-  // Introduce alias names for arguments, according to the CUTLASS
-  // naming conventions.
-  const auto& tensor_a = xq_2d;
-  const auto& tensor_a_scale = x_scale_1d;
-  const auto& tensor_b = wq;
-  const auto& tensor_b_scale = w_scale_1d;
-  const auto& tensor_c = bias;
-
-  // Create output tensor.
-  at::Tensor tensor_d =
-      tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)});
-
-  // Dispatch to appropriate kernel template.
-  dispatch_on_tensor_a_scale_and_tensor_b_scale(
-      tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d);
-
-  // Reshape and return output tensor.
-  auto tensor_d_sizes = xq_sizes;
-  tensor_d_sizes.back() = wq.size(0);
-  return tensor_d.reshape(tensor_d_sizes);
-#else
-  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
-  return at::Tensor{};
-#endif
-}
-
-TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
-  m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass);
-}
-
-}  // namespace torchao
diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h b/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h
deleted file mode 100644
index 991384b..0000000
--- a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h
+++ /dev/null
@@ -1,288 +0,0 @@
-#include <torch/extension.h>
-
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAUtils.h>
-#include <c10/util/Exception.h>
-
-#include <cuda_runtime.h>
-#include <cutlass/cutlass.h>
-#include <cutlass/gemm/device/gemm_universal.h>
-#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
-#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-
-#define CUTLASS_STATUS_CHECK(status)                                      \
-  {                                                                       \
-    TORCH_CHECK(status == cutlass::Status::kSuccess,                      \
-                __func__, " : Got CUTLASS error: ",                       \
-                cutlassGetStatusString(status));                          \
-  }
-
-namespace torchao {
-
-template<
-    typename ThreadblockShape,
-    typename WarpShape,
-    typename InstructionShape,
-    int NumStages,
-    typename ThreadblockSwizzle,
-    typename ElementA,
-    typename ElementB,
-    typename ElementAccumulator,
-    typename Operator,
-    typename ElementAScale,
-    typename ElementBScale,
-    typename ElementC,
-    typename UseTensorC,
-    typename ElementOutput>
-void scaled_linear_kernel_cutlass_sm8x(
-    const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale,
-    const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale,
-    const at::Tensor& tensor_c, at::Tensor& tensor_d) {
-  using SmArch = cutlass::arch::Sm80;
-
-  using LayoutA = cutlass::layout::RowMajor;
-  using LayoutB = cutlass::layout::ColumnMajor;
-  using LayoutOutput = cutlass::layout::RowMajor;
-
-  using ElementEpilogue = float;
-  constexpr auto NumEVTEpilogueStages = 1;
-
-  const int m = tensor_a.size(0);
-  const int n = tensor_b.size(0);
-  const int k = std::is_same<ElementA, cutlass::int4b_t>::value ?
-                tensor_a.size(1) * 2 :
-                tensor_a.size(1);
-
-  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
-  constexpr int AlignmentAScale =
-      128 / cutlass::sizeof_bits<ElementAScale>::value;
-  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
-  constexpr int AlignmentBScale =
-      128 / cutlass::sizeof_bits<ElementBScale>::value;
-  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
-  constexpr int AlignmentOutput =
-      128 / cutlass::sizeof_bits<ElementOutput>::value;
-
-  // Check for current CUTLASS limitations w.r.t. alignments.
-  TORCH_CHECK(k % AlignmentA == 0,
-              __func__, " : Number of columns of tensor A must be divisible ",
-              "by ", AlignmentA);
-  TORCH_CHECK(k % AlignmentB == 0,
-              __func__, " : Number of columns of tensor B must be divisible ",
-              "by ", AlignmentB);
-  TORCH_CHECK(n % AlignmentC == 0,
-              __func__, " : Number of columns of tensor C must be divisible ",
-              "by ", AlignmentC);
-
-  using TensorAScaleTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          ElementAScale,
-          AlignmentAScale,
-          NumEVTEpilogueStages>;
-  using TensorBScaleTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          ElementBScale,
-          AlignmentBScale,
-          NumEVTEpilogueStages>;
-  using TensorCTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          ElementC,
-          AlignmentC,
-          NumEVTEpilogueStages>;
-  using OutputTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          ElementOutput,
-          AlignmentOutput,
-          NumEVTEpilogueStages>;
-
-  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
-
-  using TensorAScale =
-      cutlass::epilogue::threadblock::VisitorColBroadcast<
-          TensorAScaleTileThreadMap,
-          ElementAScale,
-          cute::Stride<cute::_1, cute::_0, int64_t>>;
-  using TensorAScaleArguments = typename TensorAScale::Arguments;
-
-  using TensorBScale =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          TensorBScaleTileThreadMap,
-          ElementBScale,
-          cute::Stride<cute::_0, cute::_1, int64_t>>;
-  using TensorBScaleArguments = typename TensorBScale::Arguments;
-
-  using TensorCScalar =
-      cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
-  using TensorCTensor =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          TensorCTileThreadMap,
-          ElementC,
-          cute::Stride<cute::_0, cute::_1, int64_t>>;
-  using TensorC =
-      std::conditional_t<UseTensorC::value, TensorCTensor, TensorCScalar>;
-  using TensorCArguments = typename TensorC::Arguments;
-
-  using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
-      cutlass::FloatRoundStyle::round_to_nearest
-  >;
-  using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT<
-      ApplyAScale,
-      Accum,
-      TensorAScale>;
-
-  using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::multiplies, ElementEpilogue, ElementEpilogue,
-      cutlass::FloatRoundStyle::round_to_nearest
-  >;
-  using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT<
-      ApplyBScale,
-      EVTApplyAScale,
-      TensorBScale>;
-
-  using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
-      cutlass::plus, ElementEpilogue, ElementEpilogue,
-      cutlass::FloatRoundStyle::round_to_nearest
-  >;
-  using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
-      ApplySum,
-      EVTApplyBScale,
-      TensorC>;
-
-  using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
-      OutputTileThreadMap, ElementOutput,
-      cutlass::FloatRoundStyle::round_to_nearest,
-      cute::Stride<int64_t, cute::_1, int64_t> // StrideMNL
-  >;
-
-  using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
-      Output,
-      EVTApplySum>;
-
-  using EVTKernel =
-      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
-      ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
-      ElementOutput, LayoutOutput, AlignmentOutput,
-      ElementAccumulator,
-      ElementEpilogue,
-      cutlass::arch::OpClassTensorOp,
-      SmArch,
-      ThreadblockShape,
-      WarpShape,
-      InstructionShape,
-      EVTOutput,
-      ThreadblockSwizzle,
-      NumStages,
-      Operator,
-      NumEVTEpilogueStages
-  >::GemmKernel;
-
-  // GemmUniversalBase doesn't work with W4A4
-  // using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
-  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
-
-  cutlass::gemm::GemmCoord problem_size(m, n, k);
-  constexpr auto SplitKFactor = 1;
-
-  TensorAScaleArguments tensor_a_scale_arguments{
-    (ElementAScale*)tensor_a_scale.data_ptr(),
-    ElementAScale(1),
-    {cute::_1{}, cute::_0{}, problem_size.m()}
-  };
-  TensorBScaleArguments tensor_b_scale_arguments{
-    (ElementBScale*)tensor_b_scale.data_ptr(),
-    ElementBScale(1),
-    {cute::_0{}, cute::_1{}, problem_size.n()}
-  };
-  TensorCArguments tensor_c_arguments{
-    [&]() -> TensorCArguments {
-      if constexpr (UseTensorC::value) {
-        return {(ElementC*)tensor_c.data_ptr(),
-                ElementC(0),
-                {cute::_0{}, cute::_1{}, problem_size.n()}};
-      } else {
-        return {ElementC(0)};
-      }
-    }()
-  };
-  typename Output::Arguments output_arguments{
-    (ElementOutput*)tensor_d.data_ptr(),
-    {problem_size.n(), cute::_1{}, problem_size.mn().product()}
-  };
-  typename EVTOutput::Arguments callback_arguments{
-    {
-      {
-        {
-          {},                        // Accum
-          tensor_a_scale_arguments,  // TensorAScale
-          {}                         // ApplyAScale
-        },                           // EVTApplyAScale
-        tensor_b_scale_arguments,    // TensorBScale
-        {},                          // ApplyBScale
-      },                             // EVTApplyBScale
-      tensor_c_arguments,            // TensorC
-      {}                             // ApplySum
-    },                               // EVTApplySum
-    output_arguments                 // Output
-  };                                 // EVTOutput
-  // constexpr auto AvailSms = -1;
-
-  typename Gemm::Arguments arguments(
-    cutlass::gemm::GemmUniversalMode::kGemm,
-    problem_size,
-    SplitKFactor,
-    callback_arguments,              // arguments of EVT callbacks
-    (ElementA*)tensor_a.data_ptr(),
-    (ElementB*)tensor_b.data_ptr(),
-    nullptr,                         // ptr C (unused)
-    nullptr,                         // ptr D (unused)
-    problem_size.mk().product(),     // batch stride A
-    problem_size.nk().product(),     // batch stride B
-    0,                               // batch stride C (unused)
-    0,                               // batch stride D (unused)
-    problem_size.k(),                // stride A
-    problem_size.k(),                // stride B
-    0,                               // stride C (unused)
-    0
-    // ,                               // stride D (unused)
-    // AvailSms  // GemmUniversalBase requires passing AvailSms, but GemmUniversalAdapter doesn't
-    );
-
-  Gemm gemm_op;
-
-  cutlass::Status status;
-
-  // Verify that GEMM operation with given arguments can be performed
-  // by CUTLASS.
-  status = gemm_op.can_implement(arguments);
-  CUTLASS_STATUS_CHECK(status);
-
-  // Allocate workspace for CUTLASS mixed datatypes GEMM kernel.
-  const auto workspace_size = Gemm::get_workspace_size(arguments);
-  auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
-                                      at::TensorOptions().dtype(at::kByte));
-
-  // Initialize CUTLASS mixed datatypes GEMM object.
-  status = gemm_op.initialize(arguments, workspace.data_ptr(),
-                              at::cuda::getCurrentCUDAStream());
-  CUTLASS_STATUS_CHECK(status);
-
-  // Perform mixed datatypes GEMM operation.
-  status = gemm_op.run(at::cuda::getCurrentCUDAStream());
-  CUTLASS_STATUS_CHECK(status);
-
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-} // namespace torchao
diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
index d7374c8..bccbf80 100644
--- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
+++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -144,14 +144,14 @@ def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias
 
 
 def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
-    from torchao.ops import s8s4_linear_cutlass
+    from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
 
     weight = weight_tensor.tensor_impl.int_data
     weight_scale = weight_tensor.tensor_impl.scale
     input = input_tensor.tensor_impl.int_data
     input_scale = input_tensor.tensor_impl.scale
 
-    out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias)
+    out = rowwise_scaled_linear_cutlass_s8s4(input, input_scale, weight, weight_scale, bias)
 
     return out
 
diff --git a/torchao/ops.py b/torchao/ops.py
index 840dbc0..272b358 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -20,11 +20,10 @@ lib.define(
     "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor"
 )
 lib.define(
-    "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
+    "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
-lib.define("int4_mm_cutlass(Tensor A, Tensor B) -> Tensor")
 lib.define(
-    "scaled_int4_mm_cutlass(Tensor A, Tensor B, Tensor row_scale, Tensor col_scale) -> Tensor"
+    "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
 
 
@@ -518,7 +517,7 @@ def _(
     return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device)
 
 
-def s8s4_linear_cutlass(
+def rowwise_scaled_linear_cutlass_s8s4(
     input: Tensor,
     input_scale: Tensor,
     weight: Tensor,
@@ -526,23 +525,23 @@ def s8s4_linear_cutlass(
     bias: Tensor,
 ) -> Tensor:
     """
-    CUTLASS-based W4A8 linear operator.
+    CUTLASS-based row-wise scaled linear operator.
     Args:
-        input: input tensor, quantized to 8-bit integer values.
+        input: quantized input tensor, in row-major layout.
         input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension.
-        weight: weight matrix, quantized to 4-bit integer values, in row-major layout.
+        weight: quantized weight matrix, in row-major layout.
         weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension).
         bias: a vector of size equal to number of rows of weight tensor, or None.
     Returns:
         output: result tensor, in row-major layout.
     """
 
-    return torch.ops.torchao.s8s4_linear_cutlass.default(
+    return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default(
         input, input_scale, weight, weight_scale, bias
     )
 
 
-@register_custom_op("torchao::s8s4_linear_cutlass")
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s8s4")
 def _(
     input: Tensor,
     input_scale: Tensor,
@@ -550,6 +549,8 @@ def _(
     weight_scale: Tensor,
     bias: Tensor,
 ) -> Tensor:
+    # FIXME: update this!!!
+
     # Validate dtypes.
     torch._check(
         input.dtype == torch.int8,
@@ -621,29 +622,8 @@ def _(
     )
 
 
-def int4_mm_cutlass(A: Tensor, B: Tensor) -> Tensor:
-    """
-    CUTLASS-based W4A4 matmul.
-    Args:
-        A: first INT4 tensor, packed in INT8 dtype, row-major layout.
-        B: second INT4 tensor, packed in INT8 dtype, column-major layout.
-    Returns:
-        output: result tensor, in row-major layout.
-    """
-    assert A.dtype == B.dtype == torch.int8
-    assert A.ndim == B.ndim == 2
-    assert A.shape[1] == B.shape[0]
-    assert A.is_contiguous() and B.T.is_contiguous()
-    return torch.ops.torchao.int4_mm_cutlass.default(A, B)
-
-
-@register_custom_op("torchao::int4_mm_cutlass")
-def _(A: Tensor, B: Tensor) -> Tensor:
-    return A.new_empty(A.shape[0], B.shape[1], dtype=torch.int32)
-
-
-def scaled_int4_mm_cutlass(
-    A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor
+def rowwise_scaled_linear_cutlass_s4s4(
+        A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor
 ) -> Tensor:
     """
     CUTLASS-based W4A4 scaled-matmul.
@@ -656,15 +636,16 @@ def scaled_int4_mm_cutlass(
         output: result tensor, in row-major layout.
     """
     assert A.dtype == B.dtype == torch.int8
-    assert A.ndim == B.ndim == 2
-    assert A.shape[1] == B.shape[0]
-    assert A.is_contiguous() and B.T.is_contiguous()
-    assert row_scale.ndim == col_scale.ndim == 1
+    assert A.ndim >= 2
+    assert B.ndim == 2
+    assert A.shape[-1] == B.shape[-1]
+    assert A.is_contiguous() and B.is_contiguous()
+    assert row_scale.ndim == col_scale.ndim == 2
     assert row_scale.dtype == col_scale.dtype
     assert row_scale.dtype in (torch.float16, torch.bfloat16)
-    return torch.ops.torchao.scaled_int4_mm_cutlass.default(A, B, row_scale, col_scale)
+    return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default(A, row_scale, B, col_scale, bias)
 
 
-@register_custom_op("torchao::scaled_int4_mm_cutlass")
-def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor:
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s4s4")
+def _(A: Tensor, row_scale: Tensor, B: Tensor, col_scale: Tensor, bias: Tensor) -> Tensor:
     return row_scale.new_empty(A.shape[0], B.shape[1])

I did some renaming, so the s8s4_linear_cutlass sub-directory of torchao/csrc/cuda is renamed to rowwise_scaled_linear_cutlass, and then "your" files scaled_linear.h, s4s4_linear_cutlass.cu and s8s4_linear_cutlass.cu are renamed respectively to rowwise_scaled_linear_cutlass.cuh, rowwise_scaled_linear_cutlass_s4s4.cu and rowwise_scaled_linear_cutlass_s8s4.cu. The latter two files are minimalistic now: they just dispatch according to the input/weight tensor data types, and everything else is in the .cuh header file. I've updated all references to my W4A8 kernel in the code, and its test and benchmark work. I've added the changes needed for the W4A4 kernel in torchao/ops.py (actually, just adapted them from your scaled_int4_mm_cutlass that was there), so that I was able to add a W4A4 test in tests/test_rowwise_scaled_linear_cutlass.py. Note that, for the sake of running the test, this patch temporarily removes your int4_mm_cutlass and scaled_int4_mm_cutlass from torchao/ops.py - but this could be reverted; this patch is primarily to show how I'd like the C++ side of the changes to look.

Please take a look and let me know what you think.

Regarding your question about the possibility of different types for scales and bias - this pretty much comes for free, and I would be inclined to extend this further so that either of the input/weight scales and/or the bias could be of a different type. On the other hand, I have a related question here: looking at your int4_mm_cutlass and scaled_int4_mm_cutlass operators, do you think we may have a use for the case where either of the input/weight scales and/or the bias could be optional? This kernel could be easily extended to support this too, without any run-time penalty.

My TODO list regarding the attached patch:

  1. The W4A4 test produces results that differ from the reference results; need to investigate this.
  2. Have to re-examine the rowwise_scaled_linear_cutlass_check_inputs method for the S4/S4 case.
  3. Need to revisit rowwise_scaled_linear_cutlass_s8s4 and the corresponding register_custom_op in torchao/ops.py. I modeled these after the ones for the Marlin-based W4A8 kernel, while I can see yours are different in the sense that the checks are in the former method instead of in the latter. TBH, I don't fully get how this stuff works on the Python side of torchao; in particular, is there any need to repeat input checks that, in the case of this kernel, are already there in the C++ code - do you have any suggestions here?
  4. The scales could be made optional in the kernel, just like the bias is.
  5. The code is to be linted before eventually being committed.

I'll continue on this tomorrow.

@gau-nernst
Collaborator Author

Thank you for the patch! Will apply and work on it later today.

do you think we may have a use for the case where either of the input/weight scales and/or the bias could be optional?

I don't think there is a standalone real use case for INT4xINT4->INT32 (yet), similarly for INT8xINT8->INT32 i.e. torch._int_mm(), since we will definitely need some scaling anyway, and perf is better with fused output scaling. I initially added that since it might be interesting for benchmarking to see perf of INT4 tensor cores alone. For research, people might find the raw INT4xINT4->INT32 output without scaling useful for whatever reasons they are exploring?

TBH, I don't fully get how this stuff works on the Python side of torchao; in particular, is there any need to repeat input checks that, in the case of this kernel, are already there in the C++ code - do you have any suggestions here?

From my observation, it's a good idea to have input checks for the meta implementation used by torch.compile for shape tracing (otherwise we won't catch errors at compile time, only at runtime) -> we need some kind of input checks in Python anyway. Hence, I usually don't bother with input checks on the C++ (or Triton) side, and put all the checks in Python. I place the Python checks in the Python wrapper so that they are shared between the meta and CUDA implementations (actually, regardless of the implementation) -> technically the op doesn't have input checks, but the Python wrapper does. e.g.

def scaled_int8_mm(
    A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor
) -> Tensor:
    """Compute `(A @ B) * row_scale * col_scale`, where `A` and `B` are INT8 to utilize
    INT8 tensor cores. `col_scale` can be a scalar.
    """
    assert A.dtype is torch.int8 and B.dtype is torch.int8
    assert row_scale.dtype is col_scale.dtype
    assert A.shape[1] == B.shape[0]
    assert row_scale.squeeze().shape == (A.shape[0],)
    assert col_scale.squeeze().shape in ((B.shape[1],), ())
    assert row_scale.is_contiguous()
    assert col_scale.is_contiguous()
    return torch.ops.torchao.scaled_int8_mm(A, B, row_scale, col_scale)

That's just my personal view (and I prefer writing Python over C++ any time 😆).

@gau-nernst
Collaborator Author

The W4A4 test produces results that differ from the reference results; need to investigate this.

Found a typo

+  int k = tensor_a.size(1);
+  if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
+    k *= 8 % cutlass::sizeof_bits<ElementA>::value;
+  }

Should be (integer) division instead of modulo. After I fixed that, the tests pass.
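
For reference, with the modulo replaced by an integer division the snippet becomes:

  // Unpacked K dimension: each stored byte holds 8 / sizeof_bits<ElementA>
  // sub-byte elements, so the packed column count is scaled up accordingly.
  int k = tensor_a.size(1);
  if constexpr (cutlass::sizeof_bits<ElementA>::value < 8) {
    k *= 8 / cutlass::sizeof_bits<ElementA>::value;
  }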

I also took the liberty of slightly changing the test: now the matmul, scaling, and bias addition are done in FP32, which should match the numerics of the fused cutlass kernel better, especially for the scaling and bias addition parts, since FP16/BF16 matmul is accumulated in FP32 anyway. With this change, we can use torch.testing.assert_close() directly. Let me know if you are ok with it.
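
To illustrate, a minimal sketch of such an FP32 reference path (not the actual test code; the names and the handling of the unpacked integer operands are made up for the example):

import torch

def reference_rowwise_scaled_linear(xq_f32, x_scale, wq_f32, w_scale, bias=None):
    # xq_f32 / wq_f32 hold the (unpacked) quantized integer values cast to FP32.
    # Matmul, scaling and bias addition are all done in FP32, then the result is
    # cast to the scale dtype before comparing with torch.testing.assert_close().
    out = xq_f32 @ wq_f32.t()
    out = out * x_scale.float().unsqueeze(-1) * w_scale.float()
    if bias is not None:
        out = out + bias.float()
    return out.to(x_scale.dtype)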

With your refactoring, it should be trivial to add W8A8. Should we add W8A8 also? (We might not add it to AQT, since AQT currently uses torch.compile to codegen this path. Though I think inductor only fuses one scaling (either the input scale or the weight scale), so we can get a bit more perf by fusing both input and weight scales.)

Other ideas:

  • Add row-wise scaled FP8 matmul for sm89. Only need to modify logic for ElementAccumulator and SmArch I think. Will try. This can be hooked up to torch._scaled_mm() for sm89 (cc: @drisspg)
  • Runtime kernel params tuning by codegen-ing and compiling the cutlass kernel at runtime 👀. Basically cutlass python (but hard-coded for row-wise scaled linear)
    • To expand on this point: one strength of triton is runtime autotuning, which can pick the best kernel params for different hardware. CUDA kernels in PyTorch (and hard-coded triton configs, too) are usually tuned for A100/H100, so there is perf left on the table. For example, for INT8 matmul on my consumer GPU, I can get 1.5-2.5x faster than torch._int_mm() simply by having better kernel params in triton. A rough sketch of the autotuning idea follows this list.
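
The sketch below is purely illustrative (the candidate kernels are assumed to be plain callables already compiled for a given config; none of these names exist in torchao):

import torch

_best_kernel = {}  # problem-shape key -> fastest candidate found so far

def autotuned_call(candidates, key, *args):
    # Time each candidate once per problem shape and cache the winner.
    if key not in _best_kernel:
        timings = []
        for kernel in candidates:
            kernel(*args)  # warm-up (triggers any lazy compilation)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(10):
                kernel(*args)
            end.record()
            torch.cuda.synchronize()
            timings.append((start.elapsed_time(end), kernel))
        _best_kernel[key] = min(timings, key=lambda t: t[0])[1]
    return _best_kernel[key](*args)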

@alexsamardzic
Collaborator

I don't think there is a standalone real use case for INT4xINT4->INT32 (yet), similarly for INT8xINT8->INT32 i.e. torch._int_mm(), since we will definitely need some scaling anyway, and perf is better with fused output scaling. I initially added that since it might be interesting for benchmarking to see perf of INT4 tensor cores alone. For research, people might find the raw INT4xINT4->INT32 output without scaling useful for whatever reasons they are exploring?

I'd prefer that, as much as possible, CUTLASS-based kernels be implemented through what we came up with; however, for now I'd restrict this to what is asked for by users (which means we stay with W4A4/W4A8 at the moment, and extend on demand). The reason is that, through another CUTLASS-based kernel I'm working on, I hope to investigate a more systematic approach to selecting performant configs, and then eventually to update the two variations that we have, and also to have a ready-made approach for new kernels. Good config selection is the hardest thing here - for the W4A8 kernel, I spent considerable time benchmarking manually in order to come up with the corresponding config selection code, and I'm still not very happy with it.

As far as using these kernels as a base for benchmarking goes: my approach throughout development is to compare against cutlass_profiler results for the same combination of input data types (cutlass_profiler goes through a number of configs, enumerated in python/cutlass_library/generator.py in the CUTLASS source tree, so it makes sense to compare against it in order to see how the config selection performs), and then, once my CUTLASS-based kernel is ready, to compare against PyTorch timings on un-quantized data types. So I don't see a use for CUTLASS-based kernels like W8A8 for this purpose, at least not until the config selection on our side is perfected.

From my observation, it's a good idea to have input checks for the meta implementation used by torch.compile for shape tracing (otherwise we won't catch errors at compile time, only at runtime) -> we need some kind of input checks in Python anyway. Hence, I usually don't bother with input checks on the C++ (or Triton) side, and put all checks in Python. I place the Python checks in the Python wrapper so that they are shared between the meta and CUDA implementations (actually regardless of the implementation) -> technically the op doesn't have input checks, but the Python wrapper does.

My understanding matches what you said. So when the code runs in eager mode, the corresponding operator from torchao/ops.py is executed; and when the code runs in compiled mode, the tracer also goes through this operator first, and then Dynamo runs the code under @register_custom_op for the same operator, in order to get the shape for the output fake tensor. So if the checks live alongside the operator itself (either in Python or in C++), there is no need to put them anywhere else. But most of the operators in torchao/ops.py do checks under @register_custom_op for the given operator - for this reason I did it the same way for my W4A8, but was confused by that.

So I believe that, if we keep checks in C++ code, we can remove all the checks from torchao/ops.py, and we should be good. I agree that it's much easier to write these checks in Python, but I'd be inclined to keep them on the C++ side: most of this is the same for all input data types anyway, and if there are some CUTLASS-imposed limits (like for strides etc.) it seems more natural to keep them there.
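To make the C++-side checks concrete, a rough sketch of what they look like (argument names and the exact set of checks are illustrative, not the actual rowwise_scaled_linear_cutlass_check_inputs implementation):

```cpp
#include <ATen/ATen.h>  // provides at::Tensor and TORCH_CHECK

// Illustrative only: operand validation living next to the CUTLASS kernel, mixing
// generic checks with CUTLASS-imposed constraints (contiguity/alignment of K).
void check_rowwise_scaled_linear_inputs(
    const at::Tensor& xq, const at::Tensor& wq,
    const at::Tensor& x_scale, const at::Tensor& w_scale) {
  TORCH_CHECK(xq.is_cuda() && wq.is_cuda(), "operands must be CUDA tensors");
  TORCH_CHECK(xq.dim() >= 2 && wq.dim() == 2, "unexpected operand ranks");
  TORCH_CHECK(xq.size(-1) == wq.size(1), "reduction dimensions must match");
  TORCH_CHECK(xq.stride(-1) == 1 && wq.stride(1) == 1,
              "K dimension must be contiguous");
  TORCH_CHECK(x_scale.numel() == xq.numel() / xq.size(-1) &&
              w_scale.numel() == wq.size(0),
              "scales must be row-wise");
}
```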

Found a typo

...

Should be (integer) division instead of modulo. After I fixed that, the tests pass.

Awesome, sorry for the bug. The tests look great now. Shall we make an effort to extract the common stuff between the tests? I can do that.

Other ideas:

Add row-wise scaled FP8 matmul for sm89. Only need to modify logic for ElementAccumulator and SmArch I think. Will try. This can be hooked up to torch._scaled_mm() for sm89

I've implemented this through this PR; it's merged in PyTorch. I agree it belongs here, and in the future I'd like to migrate this one, as well as some other kernels (I have some CUTLASS-based F16/S8 stuff there, implemented before torchao started), into torchao.
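For context, the semantics of that row-wise scaled FP8 matmul are roughly as in the sketch below (reference math only - not the CUTLASS kernel, and not the exact torch._scaled_mm signature, which has changed across PyTorch versions):

```python
import torch

M, N, K = 16, 32, 64
a = torch.randn(M, K)
b = torch.randn(N, K)

# Row-wise quantization to float8_e4m3fn: one scale per row of each operand.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
a_scale = a.abs().amax(dim=1, keepdim=True) / fp8_max
b_scale = b.abs().amax(dim=1, keepdim=True) / fp8_max
a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
b_fp8 = (b / b_scale).to(torch.float8_e4m3fn)

# The fused kernel accumulates the FP8 matmul in FP32 and applies both row-wise
# scales in the epilogue; the reference just spells that out.
out = (a_fp8.float() @ b_fp8.float().T) * a_scale * b_scale.T
```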

Runtime kernel params tuning by codegen and compile cutlass kernel at runtime 👀. Basically cutlass python (but hard-coded for row-wise scaling linear)

Completely agree with that. There is some support for CUTLASS-based auto-tuning, and I did some work there; but it largely predates the CUTLASS Python API development, so it's based on pretty much manually generating all of the C++ code, similar to what we have in the .cuh file now (except that the CUTLASS Python API is used to generate the GEMM instantiation). Again, I'd like to revisit this and have it fully based on the CUTLASS Python API - hopefully once the above-mentioned ongoing work on the other CUTLASS-based kernel completes. (Of course, for any of this stuff, any help is more than welcome.)

For this PR, I'll try to complete remaining items from my list:

  1. Will re-check, and update if needed, the rowwise_scaled_linear_cutlass_check_inputs method.
  2. Remove the checks for "my" op in torchao/ops.py.
  3. Refactor the test file so that common stuff can be shared between tests.
  4. Add a README to rowwise_scaled_linear_cutlass, explaining how to extend the kernel with additional combinations of input data types.

Will post a patch here.

@gau-nernst
Copy link
Collaborator Author

@alexsamardzic Thank you for the update. I was not aware of the latest work in PyTorch core :)

But most of the operators in torchao/ops.py do checks under @register_custom_op for the given operator - for this reason I did it the same way for my W4A8, but was confused by that.

So I believe that, if we keep checks in C++ code, we can remove all the checks from torchao/ops.py, and we should be good.

Just want to point out a small note. Having checks under @register_custom_op (which is basically @torch.library.impl(lib, "op_name", "Meta"), I think) means that errors will be caught by the tracer (running through the meta impl) before codegen (that's my understanding). If we only have C++ checks, errors will only surface when we run the actual custom op. Not a big deal, just want to make it clear.
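A minimal sketch of that pattern (the op name, schema, and checks here are made up for illustration; this is not the actual torchao registration code):

```python
import torch
from torch.library import Library, impl

# Hypothetical op, only to show where the checks can live.
_lib = Library("my_ext", "DEF")
_lib.define("rowwise_scaled_mm(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")

def _check_inputs(a, b, a_scale, b_scale):
    # Shared Python-side checks: used by both the real (CUDA) path and the meta
    # implementation, so torch.compile tracing fails early with a clear message.
    torch._check(a.shape[-1] == b.shape[-1], lambda: "reduction dims must match")
    torch._check(a_scale.shape[0] == a.shape[0], lambda: "a_scale must be row-wise")
    torch._check(b_scale.shape[0] == b.shape[0], lambda: "b_scale must be row-wise")

@impl(_lib, "rowwise_scaled_mm", "Meta")
def _rowwise_scaled_mm_meta(a, b, a_scale, b_scale):
    _check_inputs(a, b, a_scale, b_scale)
    # Only the output shape/dtype matter for shape tracing.
    return a.new_empty((a.shape[0], b.shape[0]), dtype=torch.half)
```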

Btw, it feels wrong to have you do most of the hard work while the commits are attributed to me. Can you commit directly to my branch? If not, I can add you as a collaborator on my fork.

@alexsamardzic
Copy link
Collaborator

Just want to point out a small note. Having checks under @register_custom_op (which is basically @torch.library.impl(lib, "op_name", "Meta"), I think) means that errors will be caught by the tracer (running through the meta impl) before codegen (that's my understanding). If we only have C++ checks, errors will only surface when we run the actual custom op. Not a big deal, just want to make it clear.

You're right, and it kind of sucks. On the other hand, if the error is reported by the tracer, the error message says something like torch._dynamo.exc.TorchRuntimeError: Failed running call_function torchao.rowwise_scaled_linear_cutlass_s8s4.default(*(FakeTensor(..., device='cuda:0', size=(1, 2, 128), dtype=..., preceded by a lengthy Dynamo traceback. If the error is reported by the op itself, it simply says RuntimeError: rowwise_scaled_linear_cutlass_check_inputs : ..., which to me seems a little more specific. So I'm opting to remove the checks for the W4A8 operator, and for other operators, as far as I'm concerned, it's completely OK to do some checks there (like yours does).

Btw, it feels wrong to have you do most of the hard work while the commits are attributed to me. Can you commit directly to my branch? If not, I can add you as a collaborator on my fork.

Don't worry about that - it's fine with me, and I should have just one small patch to add.

@alexsamardzic
Copy link
Collaborator

Will post a patch here.

Here it is:
patch.txt
