# (C) Copyright IBM Corp. 2024.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import io
import os
import re
from argparse import ArgumentParser, Namespace
from typing import Any

import numpy as np
import polars as pl
from data_processing.transform import AbstractFolderTransform, TransformConfiguration
from data_processing.utils import (
    CLIArgumentProvider,
    TransformUtils,
    UnrecoverableException,
    get_logger,
)
from dpk_fdedup.Murmur_MH import Murmur_MH

short_name = "cluster"
cli_prefix = f"{short_name}_"

# configuration keys
num_bands_key = "num_bands"
""" This key holds the number of bands used in the banding technique"""
num_segments_key = "num_segments"
""" This key holds the number of segments dividing the hashing space for each band"""
jaccard_similarity_threshold_key = "jaccard_similarity_threshold"
""" This key holds the Jaccard similarity threshold above which two documents are duplicates"""
sort_output_key = "sort_output"
""" This key specifies whether to sort the output clusters"""

# command line arguments
num_bands_cli_param = f"{cli_prefix}{num_bands_key}"
""" The number of bands used in the banding technique"""
jaccard_similarity_threshold_cli_param = f"{cli_prefix}{jaccard_similarity_threshold_key}"
""" Jaccard similarity threshold above which two documents are duplicates"""
num_segments_cli_param = f"{cli_prefix}{num_segments_key}"
""" The number of segments dividing the hashing space for each band"""
sort_output_cli_param = f"{cli_prefix}{sort_output_key}"
""" Sort the output"""

captured_arg_keys = [
    num_bands_key,
    num_segments_key,
    jaccard_similarity_threshold_key,
    sort_output_key,
]

# defaults
num_bands_default = 14
""" Default number of bands used in the banding technique (from FineWeb https://arxiv.org/pdf/2406.17557)"""
jaccard_similarity_threshold_default = 0.75
""" Default Jaccard similarity threshold (from FineWeb https://arxiv.org/pdf/2406.17557)"""
num_segments_default = 1
""" Default number of segments dividing the hashing space for each band"""
sort_output_default = False
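
# The configuration keys above map to the CLI flags --cluster_num_bands,
# --cluster_num_segments, --cluster_jaccard_similarity_threshold and
# --cluster_sort_output (the "cluster_" prefix comes from cli_prefix); they are
# registered by ClusterAnalysisTransformConfiguration.add_input_params() below.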


class ClusterAnalysisTransform(AbstractFolderTransform):
    """
    This is the second transform of the fuzzy dedup pipeline. It runs in parallel:
    for each band, the hashing interval is divided into segments. A cluster analysis
    uses as input all the parquet files from one segment of a band. The `bands` output
    of the signature calculation (the first transform in the fuzzy dedup pipeline)
    contains all the data for a given segment s of a specific band b in the
    subfolder `bands/band=b/segment=s`.

    The transform loads all the parquet files in the `bands/band=b/segment=s`
    subfolder. Each of these parquet files has two columns: the `band_hash`
    and a `document_data` structure, which includes the `int_id_column` (document id),
    the `minhashes` and the `document_length` fields. Once all the files have been
    loaded into a single dataframe, a `group_by` operation on the `band_hash` field
    is performed on that dataframe. All the documents that have the same band_hash
    are grouped into a cluster. Subsequently, the documents of each cluster are sorted
    in descending order according to their size, and a Jaccard similarity is
    calculated between the cluster documents. The documents whose Jaccard
    similarity is above the `jaccard_similarity_threshold` remain in the cluster;
    the others are removed from the cluster. Finally, from each cluster that still
    has more than one document after the Jaccard filtering, we select one document
    to keep (the largest document) and mark the other documents as duplicates.
    The resulting clusters are saved in a file for further analysis.

    The following internal variables are initialized from the config parameter:
        num_bands: number of bands used in the banding technique
        jaccard_similarity_threshold: Jaccard similarity threshold above which two documents are duplicates
        num_segments: the number of segments dividing the hashing space for each band
    """

    def __init__(self, config: dict[str, Any]):
        """
        Initialize based on the dictionary of configuration information.
        This is generally called with configuration parsed from the CLI arguments
        defined by the companion runtime, ClusterAnalysisTransformRuntime.
        """
        super().__init__(config)
        self.num_bands = config.get(num_bands_key, num_bands_default)
        self.num_segments = config.get(num_segments_key, num_segments_default)
        self.jaccard_similarity_threshold = config.get(
            jaccard_similarity_threshold_key, jaccard_similarity_threshold_default
        )
        self.sort_output = config.get(sort_output_key, sort_output_default)
        self.data_access = config.get("data_access")
        if self.data_access is None:
            raise UnrecoverableException("Could not get a pointer to the data access object inside the transform.")
        self.logger = get_logger(__name__)
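
    # Processing outline for a single band=b/segment=s folder:
    #   transform() -> _consolidate_band_segment_files() -> _get_clusters()
    #               -> _analyze_clusters() -> _jaccard_distance_calculation()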

    def transform(self, folder_name: str) -> tuple[list[tuple[bytes, str]], dict[str, Any]]:
        self.logger.debug(f"Cluster analysis for folder {folder_name}")
        metadata = {}
        input_folder = TransformUtils.clean_path(os.path.join(self.data_access.input_folder, folder_name))
        files, retries = self.data_access.get_folder_files(
            path=input_folder,
            extensions=[".parquet"],
            return_data=True,
        )
        if retries > 0:
            metadata |= {"data_access_retries": retries}
        # the folder name encodes the band and segment, e.g. 'band=0/segment=3'
        match = re.match(r"^band=(\d+)/segment=(\d+)$", folder_name)
        if match:
            band = int(match.group(1))
            segment = int(match.group(2))
        else:
            raise ValueError(f"Wrong folder_name {folder_name}, should be band=b/segment=s")
        output_folder = TransformUtils.clean_path(self.data_access.output_folder)
        output_path = os.path.join(output_folder, f"band_{band}_segment_{segment}.parquet")
        # consolidate the band hashes computed by the workers into a single dataframe
        band_segment_dataframe, consolidation_stats = self._consolidate_band_segment_files(files)
        metadata |= consolidation_stats
        # group the documents into clusters by band hash
        cluster_dataframe, cluster_stats = self._get_clusters(band_segment_dataframe)
        metadata |= cluster_stats
        # cluster analysis using Jaccard similarity
        jaccard_cluster_dataframe, jaccard_stats = self._analyze_clusters(cluster_dataframe)
        metadata |= jaccard_stats
        # generate the docs_to_remove dataframe (one row per duplicate document)
        docs_to_remove_dataframe = jaccard_cluster_dataframe.explode("docs_to_remove")
        output_data = TransformUtils.convert_arrow_to_binary(docs_to_remove_dataframe.to_arrow())
        self.logger.debug(f"{len(docs_to_remove_dataframe)} documents marked to remove")
        metadata |= {"num_duplicate_documents": len(docs_to_remove_dataframe)}
        return [(output_data, output_path)], metadata

    def _consolidate_band_segment_files(self, files: dict[str, bytes]) -> tuple[pl.DataFrame, dict[str, Any]]:
        band_segment_dataframe = pl.DataFrame()
        total_input_rows = 0
        for fname, contents in files.items():
            df = pl.read_parquet(io.BytesIO(contents))
            total_input_rows += len(df)
            self.logger.debug(f"{fname} has {len(df)} rows")
            band_segment_dataframe = band_segment_dataframe.vstack(df)
        consolidation_stats = {
            "input_files": len(files),
            "input_bytes": sum(len(v) for v in files.values()),
            "input_rows": total_input_rows,
            "consolidated_files": 1,
            "consolidated_bytes": band_segment_dataframe.to_arrow().nbytes,
            "consolidated_rows": len(band_segment_dataframe),
        }
        return band_segment_dataframe, consolidation_stats

    def _get_clusters(self, band_segment_dataframe: pl.DataFrame) -> tuple[pl.DataFrame, dict[str, Any]]:
        groupby_dataframe = band_segment_dataframe.group_by("band_hash").agg("document_data")
        # keep only the clusters that contain more than one document
        cluster_dataframe = groupby_dataframe.with_columns(cluster_length=pl.col("document_data").list.len()).filter(
            pl.col("cluster_length") > 1
        )
        num_clusters = len(cluster_dataframe)
        if num_clusters > 0:
            sum_cdocs = cluster_dataframe.select(pl.sum("cluster_length")).item()
            max_cdocs = cluster_dataframe.select(pl.max("cluster_length")).item()
            min_cdocs = cluster_dataframe.select(pl.min("cluster_length")).item()
            avg_cdocs = cluster_dataframe.select(pl.mean("cluster_length")).item()
        else:
            sum_cdocs = 0
            max_cdocs = 0
            min_cdocs = 0
            avg_cdocs = 0
        self.logger.debug(f"After GroupBy: {num_clusters} clusters with {sum_cdocs} total docs")
        self.logger.debug(f"  max/min/avg docs per cluster: {max_cdocs}/{min_cdocs}/{avg_cdocs:.2f}")
        cluster_stats = {
            "groupby_clusters": num_clusters,
            "cluster_duplicate_docs": sum_cdocs,
        }
        return cluster_dataframe, cluster_stats
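
    # Each row of the cluster_dataframe returned above holds one candidate duplicate
    # cluster: band_hash, document_data (a list of document structs), cluster_length (> 1).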

    def _analyze_clusters(self, df: pl.DataFrame) -> tuple[pl.DataFrame, dict[str, Any]]:
        # Define the schema with specific data types
        schema = {"first_doc": pl.Int64, "docs_to_remove": pl.List(pl.Int64), "docs_to_remove_length": pl.Int64}
        doc_ids_lists = []
        docs_to_remove_lists = []
        len_of_docs2remove_lists = []
        for row in df.iter_rows(named=True):
            doc_ids_list, docs_to_remove_list, len_of_docs2remove_list = self._jaccard_distance_calculation(row)
            doc_ids_lists += doc_ids_list
            docs_to_remove_lists += docs_to_remove_list
            len_of_docs2remove_lists += len_of_docs2remove_list
        jaccard_cluster_dataframe = pl.DataFrame(
            {
                "first_doc": doc_ids_lists,
                "docs_to_remove": docs_to_remove_lists,
                "docs_to_remove_length": len_of_docs2remove_lists,
            },
            schema=schema,
        )
        filtered_jaccard_dataframe = jaccard_cluster_dataframe.filter(pl.col("docs_to_remove_length") > 0)
        num_clusters = len(filtered_jaccard_dataframe)
        if num_clusters > 0:
            sum_cdocs = filtered_jaccard_dataframe.select(pl.sum("docs_to_remove_length")).item()
            max_cdocs = filtered_jaccard_dataframe.select(pl.max("docs_to_remove_length")).item()
            min_cdocs = filtered_jaccard_dataframe.select(pl.min("docs_to_remove_length")).item()
            avg_cdocs = filtered_jaccard_dataframe.select(pl.mean("docs_to_remove_length")).item()
        else:
            sum_cdocs = 0
            max_cdocs = 0
            min_cdocs = 0
            avg_cdocs = 0
        self.logger.debug(f"After Jaccard: {num_clusters} clusters with {sum_cdocs} total docs")
        self.logger.debug(f"  max/min/avg docs per cluster: {max_cdocs}/{min_cdocs}/{avg_cdocs:.2f}")
        jaccard_stats = {
            "jaccard_clusters": num_clusters,
            "jaccard_duplicate_docs": sum_cdocs,
        }
        if self.sort_output:
            filtered_jaccard_dataframe = filtered_jaccard_dataframe.sort(by="first_doc")
        return filtered_jaccard_dataframe, jaccard_stats

    def _jaccard_distance_calculation(self, row: dict[str, Any]) -> tuple[list, list, list]:
        # Process one cluster row and return, for each kept document, the list of
        # documents that should be removed as duplicates.
        threshold = self.jaccard_similarity_threshold
        doc_ids_list = []
        docs_to_remove_list = []
        len_of_docs2remove_list = []
        document_data = row["document_data"]
        # sort the documents by descending 'document_length' (ties broken by document id)
        sorted_document_data = sorted(document_data, key=lambda x: (-x["document_length"], x["int_id_column"]))
        # extract the int_id_column values into a list
        doc_list = [item["int_id_column"] for item in sorted_document_data]
        # build a dictionary with int_id_column as key and minhashes as value
        doc_minhashes = {item["int_id_column"]: item["minhashes"] for item in sorted_document_data}
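        # Greedy clustering pass, illustrated on a toy cluster (ids sorted by size
        # descending, threshold 0.75):
        #   doc_list = [A, B, C]
        #   pass 1: keep A; jaccard(A, B) >= 0.75 -> remove B; jaccard(A, C) < 0.75 -> defer C
        #   pass 2: keep C; nothing left to compare -> loop ends
        #   result: doc_ids_list = [A], docs_to_remove_list = [[B]], len_of_docs2remove_list = [1]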
        while len(doc_list) > 1:
            docs_to_remove = []
            new_doc_list = []
            # this is the document we are going to keep
            first_doc = doc_list[0]
            first_mh = doc_minhashes[first_doc]
            for int_id_column in doc_list[1:]:
                doc_mh = doc_minhashes[int_id_column]
                distance = Murmur_MH.jaccard(np.array(first_mh), np.array(doc_mh))
                if distance >= threshold:
                    docs_to_remove.append(int_id_column)
                else:
                    new_doc_list.append(int_id_column)
            if len(docs_to_remove) > 0:
                docs_to_remove = list(set(docs_to_remove))
                doc_ids_list.append(first_doc)
                docs_to_remove_list.append(docs_to_remove)
                len_of_docs2remove_list.append(len(docs_to_remove))
            doc_list = new_doc_list
        return doc_ids_list, docs_to_remove_list, len_of_docs2remove_list


class ClusterAnalysisTransformConfiguration(TransformConfiguration):
    """
    Provides support for configuring and using the associated Transform class,
    including configuration with CLI args.
    """

    def __init__(self):
        super().__init__(
            name=short_name,
            transform_class=ClusterAnalysisTransform,
            remove_from_metadata=[],
        )
        self.logger = get_logger(__name__, level="INFO")

    def add_input_params(self, parser: ArgumentParser) -> None:
        """
        Add Transform-specific arguments to the given parser.
        These will be included in a dictionary used to initialize the ClusterAnalysisTransform.
        By convention a common prefix should be used for all transform-specific CLI args
        (e.g., noop_, pii_, etc.)
        """
        parser.add_argument(
            f"--{jaccard_similarity_threshold_cli_param}",
            type=float,
            default=jaccard_similarity_threshold_default,
            help="Jaccard similarity threshold above which two documents are duplicates",
        )
        parser.add_argument(
            f"--{num_bands_cli_param}",
            type=int,
            default=num_bands_default,
            help="The number of bands used in the banding technique",
        )
        parser.add_argument(
            f"--{num_segments_cli_param}",
            type=int,
            default=num_segments_default,
            help="The number of segments dividing the hashing space for each band",
        )
        # note: argparse's type=bool treats any non-empty string (including "False") as True
        parser.add_argument(
            f"--{sort_output_cli_param}",
            type=bool,
            default=sort_output_default,
            help="Sort the similarity clusters by the document ID of the kept doc (used primarily for testing)",
        )

    def apply_input_params(self, args: Namespace) -> bool:
        """
        Validate and apply the arguments that have been parsed.
        :param args: user defined arguments.
        :return: True if validation passes, False otherwise
        """
        captured = CLIArgumentProvider.capture_parameters(args, cli_prefix, False)
        self.params = self.params | captured
        self.logger.info(f"{short_name} parameters are : {self.params}")
        return True
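

if __name__ == "__main__":
    # Illustrative sketch only, not part of the transform and not how the runtime
    # invokes this class. It demonstrates, on toy data, the two core ideas used above.
    # The column names mirror what this transform assumes; the minhash-agreement
    # formula is the usual estimator and Murmur_MH.jaccard may be implemented differently.

    # 1. Estimate Jaccard similarity from two minhash signatures as the fraction
    #    of positions where the signatures agree.
    mh_a = np.array([3, 7, 7, 1, 9, 4, 2, 8])
    mh_b = np.array([3, 7, 5, 1, 9, 4, 2, 8])
    print("estimated Jaccard similarity:", np.count_nonzero(mh_a == mh_b) / mh_a.size)

    # 2. Group documents by band hash, as _get_clusters() does, keeping only the
    #    band hashes shared by more than one document.
    toy = pl.DataFrame(
        {
            "band_hash": [111, 111, 222],
            "document_data": [
                {"int_id_column": 1, "minhashes": [3, 7, 7, 1], "document_length": 100},
                {"int_id_column": 2, "minhashes": [3, 7, 5, 1], "document_length": 90},
                {"int_id_column": 3, "minhashes": [5, 8, 6, 2], "document_length": 50},
            ],
        }
    )
    clusters = (
        toy.group_by("band_hash")
        .agg("document_data")
        .with_columns(cluster_length=pl.col("document_data").list.len())
        .filter(pl.col("cluster_length") > 1)
    )
    print(clusters)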