Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DCLM Style Deduplications #214

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
831 changes: 463 additions & 368 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ flate2 = { version = "1.0.28", features = [
"zlib-ng",
], default-features = false }
glob = "0.3.1"
human_bytes = "0.4.3"
humantime = "2.1"
indicatif = "0.17"
jsonpath-rust = "0.3.0"
Expand All @@ -42,6 +43,7 @@ simple_logger = { version = "3.0", features = [
"colors",
], default-features = false, optional = true }
structopt = { version = "0.3", optional = true }
sysinfo="0.30.7"
thousands = "0.2"
threadpool = "1.8.1"
tokenizers = { version = "0.15.0", features = ["http"] }
Expand All @@ -56,6 +58,7 @@ jaq-std = "1.2.1"
jaq-parse = "1.0.2"
jaq-interpret = { version = "1.2.1", features = ["serde_json"] }
zstd = "0.13.1"
once_cell = "1.20.2"

[dev-dependencies]
tempfile = "3.10.1"
Expand Down
1 change: 1 addition & 0 deletions docs/examples/dedupe-by-url.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"output": "tests/work/url/output"
},
"dedupe": {
"dedupe_method": "documents",
"name": "dedupe_by_url",
"documents": {
"attribute_name": "bff_duplicate_url",
Expand Down
1 change: 1 addition & 0 deletions docs/examples/dedupe-paragraphs.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"output": "tests/work/para/output"
},
"dedupe": {
"dedupe_method": "paragraphs",
"name": "dedupe_paragraphs",
"paragraphs": {
"attribute_name": "bff_duplicate_paragraph_spans"
Expand Down
1 change: 1 addition & 0 deletions docs/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ After tagging, we deduplicate the dataset at a paragraph level.
```shell
dolma dedupe \
--documents "wikipedia/v0/documents/*" \
--dedupe.dedupe_method "paragraphs" \
--dedupe.paragraphs.attribute_name 'bff_duplicate_paragraph_spans' \
--dedupe.skip_empty \
--bloom_filter.file /tmp/deduper_bloom_filter.bin \
Expand Down
76 changes: 57 additions & 19 deletions python/dolma/cli/deduper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import ExitStack
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -46,6 +47,18 @@ class ParagraphDedupeConfig:
)


@dataclass
class DCLMDedupeConfig:
    """Configuration for DCLM-style deduplication.

    Bundles the output attribute name with optional ngram (fuzzy) dedupe
    settings and the separator used to split documents into paragraphs.
    """

    # Name of the attribute field the deduper writes into the tagger output.
    attribute_name: Optional[str] = field(help="Name of the output field in the tagger")
    # NOTE(review): both `default=None` and `default_factory=NgramDedupeConfig` are
    # passed here; the stdlib `dataclasses.field` forbids that combination, so this
    # relies on the project's custom `field` wrapper -- confirm which default wins.
    by_ngram: Optional[NgramDedupeConfig] = field(
        default=None, help="Configuration for fuzzy dedupe", default_factory=NgramDedupeConfig
    )
    # Paragraph delimiter; documents are split on this before ngram dedupe.
    paragraph_separator: Optional[str] = field(
        default="\n",
        help="String to use to separate paragraphs. By default, paragraphs are separated by newlines.",
    )


@dataclass
class DocumentDedupeConfig:
attribute_name: Optional[str] = field(help="Name of the output field in the tagger")
Expand Down Expand Up @@ -77,6 +90,15 @@ class BloomFilterConfig:
"estimated_doc_count."
),
)
save_to_disk: bool = field(
default=True, help=("If False, ignore the 'file' field and do NOT save the populated bloom filter to disk")
)
sysram_limit: float = field(
default=0.9,
help=(
"Maximum fraction of the system RAM we use -- will print out a warning if we really want more than this"
),
)


@dataclass
Expand All @@ -88,6 +110,7 @@ class DedupeConfig:
paragraphs: Optional[ParagraphDedupeConfig] = field(
default=None, help="Configuration for paragraph deduplication"
)
dclm: Optional[DCLMDedupeConfig] = field(default=None, help="Configuration for DCLM deduplication")
skip_empty: Optional[bool] = field(default=False, help="If true, empty documents/paragraphs will be skipped")
min_length: Optional[int] = field(default=0, help="Minimum length of documents/paragraphs to be deduplicated")
min_words: Optional[int] = field(
Expand All @@ -99,6 +122,10 @@ class DedupeConfig:
partition_index: Optional[int] = field(
default=0, help="The index of the partition being processed, in the range [0, num_partitions)."
)
dedupe_method: Optional[str] = field(
default=None,
help="Selects which dedupe method to use. Must be either empty or in the set {paragraphs, documents, dclm}",
)


@dataclass
Expand All @@ -108,7 +135,7 @@ class DeduperConfig:
dedupe: DedupeConfig = field(help="Deduplication configuration. Required.")
bloom_filter: BloomFilterConfig = field(help="Bloom filter configuration. Required.")
processes: int = field(
default=1, help="Number of processes to use for deduplication. If 1, no multiprocessing will be used."
default=0, help="Number of processes to use for deduplication. If 1, no multiprocessing will be used."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0 means we use maximum parallelism (the number of processes becomes the number of available cores). I assumed we want this behavior almost all of the time.

This might not actually play nicely with Beaker nodes and how CPUs get allocated there. I'll fall back on ai2 best practices here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the help string to reflect this since it's non-obvious

)
compression: CompressionConfig = field(
default=CompressionConfig(),
Expand All @@ -135,7 +162,6 @@ def run(cls, parsed_config: DeduperConfig):
logger = get_logger("tagger")

dict_config: Dict[str, Any] = {}

with ExitStack() as stack:
work_dirs = stack.enter_context(make_workdirs(parsed_config.work_dir))

Expand All @@ -155,22 +181,29 @@ def run(cls, parsed_config: DeduperConfig):
if dedupe_dict_config["min_words"] < 0:
raise ValueError("min_words must be >= 0")

# add either the document or paragraph dedupe config
if not (
om.is_missing(parsed_config.dedupe.documents, "attribute_name")
and om.is_missing(parsed_config.dedupe.documents, "key")
):
cfg = om.to_container(parsed_config.dedupe.documents)
assert isinstance(cfg, dict), "Expected dedupe.documents to be a dict"
dedupe_dict_config["documents"] = cfg
try_name = try_name or cfg["attribute_name"]
elif not om.is_missing(parsed_config.dedupe.paragraphs, "attribute_name"):
cfg = om.to_container(parsed_config.dedupe.paragraphs)
assert isinstance(cfg, dict), "Expected dedupe.paragraphs to be a dict"
dedupe_dict_config["paragraphs"] = cfg
try_name = try_name or cfg["attribute_name"]
else:
raise ValueError("Either dedupe.documents or dedupe.paragraphs must be specified")
# add either the document or paragraph dedupe config and infer the dedup_method
dedupe_method = parsed_config.dedupe.dedupe_method # If is specified
if dedupe_method == None:
# Else infer the dedupe method:
if not (
om.is_missing(parsed_config.dedupe.documents, "attribute_name")
and om.is_missing(parsed_config.dedupe.documents, "key")
):
dedupe_method = "documents"
elif not (om.is_missing(parsed_config.dedupe.paragraphs, "attribute_name")):
dedupe_method = "paragraphs"
elif not (om.is_missing(parsed_config.dedupe.dclm, "attribute_name")):
dedupe_method = "dclm"
else:
raise ValueError("Some dedupe method must be specified (either explicitly or implicitly)")
dedupe_dict_config["dedupe_method"] = dedupe_method
dedupe_dict_config[dedupe_method] = om.to_container(parsed_config.dedupe[dedupe_method])
assert (
dedupe_dict_config[dedupe_method].get("attribute_name") != None
), "Need attribute name for deduplication"
cfg = om.to_container(parsed_config.dedupe[dedupe_method])
assert isinstance(cfg, dict), "Expected dedupe.%s to be a dict" % dedupe_meth
try_name = try_name or cfg["attribute_name"]

if try_name is None:
raise ValueError("dedupe.name must be specified")
Expand Down Expand Up @@ -215,6 +248,7 @@ def run(cls, parsed_config: DeduperConfig):
"size_in_bytes": int(parsed_config.bloom_filter.size_in_bytes),
"estimated_doc_count": int(parsed_config.bloom_filter.estimated_doc_count),
"desired_false_positive_rate": float(parsed_config.bloom_filter.desired_false_positive_rate),
"save_to_disk": parsed_config.bloom_filter.save_to_disk,
}

if dict_config["bloom_filter"]["size_in_bytes"] <= 0 and (
Expand Down Expand Up @@ -247,7 +281,11 @@ def run(cls, parsed_config: DeduperConfig):
deduper(dict_config)

# upload to remote file if necessary
if not parsed_config.bloom_filter.read_only and not path_is_local:
if (
not parsed_config.bloom_filter.read_only
and not path_is_local
and parsed_config.bloom_filter.save_to_disk
):
print(f"Pushing Bloom filter to {parsed_config.bloom_filter.file}")
local = stack.enter_context(smart_open.open(local_bloom_file, "rb"))
remote = stack.enter_context(smart_open.open(parsed_config.bloom_filter.file, "wb"))
Expand Down
87 changes: 82 additions & 5 deletions src/bloom_filter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use ahash::RandomState;
use byteorder::{LittleEndian, NativeEndian, ReadBytesExt, WriteBytesExt};
use human_bytes::human_bytes;
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::fs::{create_dir_all, OpenOptions};
Expand All @@ -10,6 +12,8 @@ use std::io::{BufReader, BufWriter, Write};
use std::mem::size_of;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU32, Ordering};
use sysinfo::System;

mod bloom_test;
// A thread-safe bloom filter.
pub struct BloomFilter {
Expand Down Expand Up @@ -59,6 +63,61 @@ impl BloomFilter {
size_in_bytes
}

pub fn compute_bloom_size_binsearch(
    expected_elements: usize,
    fp_rate: f64,
    sysram_limit: Option<f64>,
    num_hashers: usize,
) -> usize {
    /* Binary-searches for the smallest bloom filter size (in bytes) whose
       predicted false-positive rate is <= fp_rate for expected_elements.

       - If sysram_limit is Some(frac) with frac > 0.0, the size is capped at
         that fraction of total system RAM; a warning is logged when the cap
         makes the requested fp_rate unattainable and the cap is returned.
       - If num_hashers == 0, the optimal hasher count is recomputed for each
         candidate size; otherwise the provided count is used throughout.
    */

    // Establish binary-search bounds; the upper bound is either the RAM cap
    // or an effectively unbounded size.
    let mut sys = System::new_all();
    sys.refresh_all();
    let sysram_limit: f64 = sysram_limit.unwrap_or(0.0);
    let mut lo: usize = 1;
    let mut hi: usize = if sysram_limit > 0.0 {
        ((sys.total_memory() as f64) * sysram_limit) as usize
    } else {
        usize::MAX / 8
    };

    let compute_hashers = num_hashers == 0;
    let num_hashers = if compute_hashers {
        BloomFilter::optimal_number_of_hashers(hi, expected_elements)
    } else {
        num_hashers
    };

    // If even the largest allowed filter cannot reach the target fp rate,
    // warn and settle for the cap.
    if (sysram_limit > 0.0)
        && BloomFilter::prob_of_false_positive(hi, expected_elements, num_hashers) > fp_rate
    {
        log::info!("WARNING: to achieve the desired false-positive rate, you'd need more than {:?} of system RAM. Capping the bloom filter at that fraction", sysram_limit);
        return hi;
    }

    // Invariant: fp(hi) <= fp_rate (fp is monotone non-increasing in size).
    // Shrink [lo, hi] to the smallest feasible size; never discard a size
    // that is known to be feasible (bug fix: previously `hi = mid - 1`
    // could return a size whose fp rate exceeds the target).
    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        let hashers = if compute_hashers {
            BloomFilter::optimal_number_of_hashers(mid, expected_elements)
        } else {
            num_hashers
        };
        if BloomFilter::prob_of_false_positive(mid, expected_elements, hashers) > fp_rate {
            // mid is too small to meet the target rate; search above it.
            lo = mid + 1;
        } else {
            // mid is feasible; keep it as the current best candidate.
            hi = mid;
        }
    }
    hi
}

#[allow(dead_code)]
pub fn my_prob_of_false_positive(&self, expected_elements: usize) -> f64 {
Self::prob_of_false_positive(
Expand All @@ -68,6 +127,19 @@ impl BloomFilter {
)
}

/// Returns the fraction of bits currently set in the filter, in [0.0, 1.0].
///
/// Counts set bits across all words in parallel (relaxed loads are fine
/// here: this is a diagnostic snapshot, not a synchronization point) and
/// divides by the total bit capacity.
pub fn calculate_sparsity(&self) -> f64 {
    let total_bits = (self.size_in_bytes() * 8) as f64;
    let ones_count: usize = self
        .bits
        .par_iter()
        .map(|word| word.load(std::sync::atomic::Ordering::Relaxed).count_ones() as usize)
        .sum();
    ones_count as f64 / total_bits
}

#[allow(dead_code)]
pub fn size_in_bytes(&self) -> usize {
self.bits.len() * size_of::<AtomicU32>()
Expand All @@ -86,8 +158,9 @@ impl BloomFilter {
}

let number_of_u32 = size_in_bytes / size_of::<AtomicU32>();
let bits: Vec<AtomicU32> = std::iter::repeat_with(|| AtomicU32::new(0))
.take(number_of_u32)
let bits = (0..number_of_u32)
.into_par_iter()
.map(|_| AtomicU32::default())
.collect();
Self {
bits,
Expand Down Expand Up @@ -243,11 +316,13 @@ impl BloomFilter {
log::info!("Creating new bloom filter...");
let mut bloom_filter_size: usize = config.size_in_bytes;
if bloom_filter_size == 0 {
bloom_filter_size = BloomFilter::suggest_size_in_bytes(
bloom_filter_size = BloomFilter::compute_bloom_size_binsearch(
config.estimated_doc_count,
config.desired_false_positive_rate,
config.sysram_limit,
0,
);
log::info!("Creating bloom filter with size {} bytes to achieve false positive rate {} for {} elements", bloom_filter_size, config.desired_false_positive_rate, config.estimated_doc_count);
log::info!("Creating bloom filter with size {} bytes to achieve false positive rate {} for {} elements", human_bytes(bloom_filter_size as f64), config.desired_false_positive_rate, config.estimated_doc_count);
}
let num_hashers = BloomFilter::optimal_number_of_hashers(
bloom_filter_size,
Expand All @@ -260,7 +335,7 @@ impl BloomFilter {
);
log::info!(
"Bloom filter will have size {}, {} hashers, false positive rate {}.",
bloom_filter_size,
human_bytes(bloom_filter_size as f64),
num_hashers,
p
);
Expand All @@ -278,4 +353,6 @@ pub struct BloomFilterConfig {
pub read_only: bool,
pub estimated_doc_count: usize,
pub desired_false_positive_rate: f64,
pub save_to_disk: bool,
pub sysram_limit: Option<f64>,
}
Loading