From 90cd0096309d44b60000a667888759d261c1338a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sat, 18 Oct 2025 00:07:22 -0400 Subject: [PATCH 01/12] Memmap dataset --- fast_llm/data/dataset/config.py | 32 ++ fast_llm/data/dataset/gpt/config.py | 36 +-- .../gpt/{memmap.py => legacy_memmap.py} | 119 +------ fast_llm/data/dataset/memmap.py | 100 ++++++ fast_llm/data/preparator/gpt_memmap/config.py | 53 +++- .../data/preparator/gpt_memmap/prepare.py | 290 ++++++------------ fast_llm/data/sample/abstract.py | 152 +++++++++ fast_llm/data/sample/language_model.py | 186 ++++++++++- fast_llm/data/sample/range.py | 81 ++++- fast_llm/data/sample/token.py | 86 +++++- fast_llm/data/tokenizer.py | 50 +-- tests/data/test_blending.py | 10 +- tests/data/test_concatenate.py | 6 +- tests/data/test_dataset_from_file.py | 4 +- tests/data/test_fim.py | 6 +- tests/data/test_memmap.py | 14 +- tests/data/test_prepare_gpt_memmap.py | 44 +-- tests/data/test_sampling.py | 10 +- tests/data/test_slice.py | 10 +- tests/models/test_match_megatron.py | 28 +- tests/utils/dataset.py | 50 ++- tests/utils/global_variables.py | 5 +- tests/utils/model_configs.py | 10 +- tools/concatenate_dataset.py | 60 ---- 24 files changed, 899 insertions(+), 543 deletions(-) rename fast_llm/data/dataset/gpt/{memmap.py => legacy_memmap.py} (61%) create mode 100644 fast_llm/data/dataset/memmap.py delete mode 100644 tools/concatenate_dataset.py diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 20e40b66e..f1bc3472a 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -2,6 +2,7 @@ import enum import functools import itertools +import logging import math import pathlib import typing @@ -9,12 +10,15 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.sample.abstract import Sample +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class ShufflingType(str, enum.Enum): # Shuffle all epochs together. Not extendable. @@ -266,3 +270,31 @@ def build_and_sample( self.weights, sampling, ) + + +@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) +class MemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): + _abstract: typing.ClassVar[bool] = False + path: pathlib.Path = Field( + default=None, + desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", + hint=FieldHint.core, + ) + + def build(self) -> "IndexedDataset[SampleType]": + name = str(self.path).replace("/", "__") + if self.path.is_file(): + from fast_llm.data.dataset.memmap import MemmapDataset + + return MemmapDataset[SampleType](name, self.path) + elif self.path.with_suffix(".bin").is_file() and self.path.with_suffix(".idx").is_file(): + logger.warning( + "Using the legacy memmap dataset format." + " This format is deprecated and will be removed in a future release." + " Please recreate the dataset in the new memmap format." 
+ ) + from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset + + return LegacyMemmapDataset[SampleType](name, self.path) + else: + raise FileNotFoundError(self.path) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 15f54ec80..9ff6654c2 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -8,19 +8,12 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset -from fast_llm.data.dataset.config import ( - IndexedDatasetConfig, - SamplableDatasetConfig, - SampledDatasetConfig, - SamplingData, - SamplingParameters, -) +from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset - from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset @@ -60,33 +53,6 @@ def build(self) -> "GPTRandomDataset[SampleType]": return GPTRandomDataset[SampleType](self.name) -@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): - _abstract: typing.ClassVar[bool] = False - path: pathlib.Path = Field( - default=None, - desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", - hint=FieldHint.core, - ) - num_documents: int | None = Field( - default=None, - desc="Expected number of documents in the dataset.", - hint=FieldHint.optional, - ) - num_tokens: int | None = Field( - default=None, - desc="Expected number of tokens in the dataset.", - hint=FieldHint.optional, - ) - - def build(self) -> "GPTMemmapDataset[SampleType]": - from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - - return GPTMemmapDataset[SampleType]( - str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens - ) - - @config_class(dynamic_type={SampledDatasetConfig: "file"}) class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py similarity index 61% rename from fast_llm/data/dataset/gpt/memmap.py rename to fast_llm/data/dataset/gpt/legacy_memmap.py index 06d8d7acc..d8c63e9f9 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -1,21 +1,19 @@ import pathlib import struct -import typing import numpy as np import torch from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER +from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div -class GPTMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): +class 
LegacyMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. a pair of numpy file containing @@ -28,12 +26,10 @@ def __init__( self, name: str, prefix: pathlib.Path | str, - num_documents: int | None = None, - num_tokens: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init(self, name: str, prefix: pathlib.Path | str) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) @@ -54,9 +50,6 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None _ = struct.unpack(" tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._prefix) - def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): + def __setstate__(self, state: tuple[str, pathlib.Path]): self._init(*state) def __del__(self): @@ -168,7 +159,7 @@ def get_document( token_ids = token_ids.to(torch.int64) if parameters is not None and parameters.use_loss_masking_spans: assert self._spans is not None - # TODO: ====== Store in range format (begin, end) ====== + # Convert to in range format (begin, end). sample_spans = RangeSample( [(begin_, last_ + 1) for begin_, last_ in self._spans[index].tolist()], sample_size ).crop(begin, end) @@ -182,7 +173,7 @@ def get_document( raise ValueError("Failed to read chosen spans from memmap dataset.") elif self._has_preference_spans and self._rejected_spans is None: raise ValueError("Failed to read rejected spans from memmap dataset.") - # TODO: ====== Store in range format ====== + # Convert to in range format (begin, end). chosen_spans = RangeSample( [(self._chosen_spans[index][0].item(), self._chosen_spans[index][1].item() + 1)], sample_size, @@ -222,95 +213,3 @@ def get_document_sizes(self) -> torch.Tensor: def get_document_size(self, index: int) -> int: return self._document_sizes[index].item() - - @classmethod - def write_dataset( - cls, - prefix: pathlib.Path | str, - documents: typing.Iterable[tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]], - ) -> None: - # Initialize metadata - dtype = None - num_documents = 0 - lengths = [] - pointers = [] - offset = 0 - # number of spans for each document - num_spans = [] - spans = [] - chosen_spans = [] - rejected_spans = [] - - prefix = pathlib.Path(prefix) - prefix.parent.mkdir(parents=True, exist_ok=True) - - # Write the binary data file (.bin) lazily - with prefix.with_suffix(".bin").open("wb") as bin_stream: - for token_ids, loss_masking_spans, chosen_span, rejected_span in documents: - # Infer dtype from the first document - if dtype is None: - dtype = token_ids.dtype - assert dtype is not None, "Document dtype could not be inferred from the data." - - # Ensure all documents have the same dtype - assert token_ids.dtype == dtype, f"Expected dtype {dtype}, got {token_ids.dtype}." 
- - # Write document to binary file - bin_stream.write(token_ids.numpy().tobytes(order="C")) - - # Update metadata - doc_length = len(token_ids) - lengths.append(doc_length) - pointers.append(offset) - if loss_masking_spans is not None: - num_spans.append(len(loss_masking_spans)) - spans.append(loss_masking_spans) - if chosen_span is not None: - chosen_spans.append(chosen_span) - if rejected_span is not None: - rejected_spans.append(rejected_span) - offset += doc_length * dtype.itemsize - num_documents += 1 - - # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) - pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: - spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) - - # Write the index file (.idx) - with prefix.with_suffix(".idx").open("wb") as idx_stream: - idx_stream.write(MEMMAP_INDEX_HEADER) - # Indicates the version - # Version 2 optionally adds loss-masking spans - # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) - # Flag to indicate whether preference loss-masking spans are present - idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) - # Data type - idx_stream.write(struct.pack(" None: + super().__init__() + self._name = name + self._path = path + + with self._path.open("rb") as stream: + # Very file type. + assert stream.read(len(FILE_HEADER)) == FILE_HEADER + # Go to reader configs. + stream.seek(int.from_bytes(stream.read(4), signed=False)) + # Read the reader config. + reader_config = MemmapIndexDatasetReaderConfig.from_dict( + json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) + ) + + self._memmap = np.memmap(self._path, mode="r") + # TODO: ===== Check num_documents, num_tokens ====== + self._reader = reader_config.get_reader(memoryview(self._memmap)) + + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._path) + + def __setstate__(self, state: tuple[str, pathlib.Path]): + self._init(*state) + + def __del__(self): + if hasattr(self, "_memmap"): + self._memmap._mmap.close() # noqa + del self._memmap + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: + return self._reader.get_document(index, begin, end) + + @property + def name(self) -> str: + return self._name + + def __len__(self) -> int: + return self._reader + + # TODO: ====== needed? ====== + # @property + # def num_tokens(self) -> int: + # return self._reader.num_tokens + + def get_document_sizes(self) -> torch.Tensor: + return self._reader.get_document_sizes() + + def get_document_size(self, index: int) -> int: + return self._reader.get_document_size(index) + + @classmethod + def write_dataset(cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter]): + # TODO: Match `writer_class` with `SampleType`? + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("wb") as stream: + # Write the file type header. + stream.write(FILE_HEADER) + # Leave space for a pointer to the reader config. + # We write the config at the end since we don't know it yet. + start = stream.tell() + stream.seek(start + 4) + # Write the data. 
+ reader_config = writer_class.write_dataset(stream, documents) + # Write the reader config. + config_offset = stream.tell() + reader_config_bytes = json.dumps(reader_config.to_dict()).encode("utf-8") + stream.write(len(reader_config_bytes).to_bytes(4, signed=False)) + stream.write(reader_config_bytes) + # Write a pointer to the reader config. + stream.seek(start) + stream.write(config_offset.to_bytes(4, signed=False)) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d2aaee5e2..c193cf942 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -1,3 +1,4 @@ +import functools import os import pathlib import typing @@ -25,14 +26,9 @@ MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" -@config_class(registry=True) -class SourceSchemaConfig(Config): - pass - - -@config_class(dynamic_type={SourceSchemaConfig: "text_column"}) -class TextColumnConfig(SourceSchemaConfig): - input_column: str = Field( +@config_class() +class LanguageModelSourceConfig(Config): + text_column: str = Field( default="text", desc="Field of the dataset to use.", hint=FieldHint.optional, @@ -40,6 +36,38 @@ class TextColumnConfig(SourceSchemaConfig): loss_masking_spans_column: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) + chosen_spans_column: None | str = Field( + default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional + ) + rejected_spans_column: None | str = Field( + default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional + ) + + @functools.cached_property + def columns(self) -> list[str]: + columns = [self.text_column] + if self.has_loss_masking_span: + columns.append(self.loss_masking_spans_column) + if self.has_preference_spans: + columns.extend([self.chosen_spans_column, self.rejected_spans_column]) + return columns + + @functools.cached_property + def has_loss_masking_span(self) -> bool: + return self.loss_masking_spans_column is not None + + @functools.cached_property + def has_preference_spans(self) -> bool: + Assert.eq(self.chosen_spans_column is None, self.rejected_spans_column is None) + return self.chosen_spans_column is not None + + def _validate(self): + super()._validate() + if self.has_loss_masking_span != self.rejected_spans_column is not None: + raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") + if self.has_preference_spans == self.has_loss_masking_span: + # TODO: ====== Still needed? ====== + raise ValueError(f"Can not enable both loss masking and preference spans.") @config_class() @@ -69,16 +97,10 @@ class GPTHuggingfaceDatasetConfig(Config): desc="Split of the dataset to use.", hint=FieldHint.optional, ) - source_schema: SourceSchemaConfig = Field( + source_schema: LanguageModelSourceConfig = Field( desc="Configuration for the data source.", hint=FieldHint.optional, ) - chosen_text: None | str = Field( - default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional - ) - rejected_text: None | str = Field( - default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional - ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." 
@@ -133,7 +155,6 @@ def _validate(self) -> None: @config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"}) class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): preparator_name: typing.ClassVar[str] = "gpt_memmap" - output_path: pathlib.Path = Field( default=None, desc="Output directory for the processed dataset.", diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 274bbf1b0..18ab2d787 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,5 +1,6 @@ import json import logging +import math import multiprocessing import pathlib import shutil @@ -18,13 +19,15 @@ BlendedDatasetConfig, DatasetSliceConfig, IndexedDatasetConfig, + MemmapDatasetConfig, SampledDatasetConfig, ) -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -35,154 +38,24 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): _tokenizer: Tokenizer _data_type: DataType - _text_column: str - _loss_masking_spans_column: str | None _sample_type: typing.ClassVar[type[LanguageModelSample]] = LanguageModelSample + _config: GPTMemmapDatasetPreparatorConfig - def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } - - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( - list, - zip( - *[ - ( - np.array(input_ids, dtype=self._data_type.numpy), - np.array(token_spans, dtype=np.int32).reshape(-1, 2), - ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column]) - ] - ] - ), - ) - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "token_spans": token_spans, - "num_tokens": num_tokens, - } - - def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - packed_texts = [] - chosen_spans = [] - rejected_spans = [] - - for conv_history, chosen_text, rejected_text in zip( - batch[self._config.dataset.field], - batch[self._config.dataset.chosen_text], - batch[self._config.dataset.rejected_text], - ): - # compute chosen span - full_chosen_text = conv_history + 
chosen_text + self._tokenizer.tokenizer.eos_token - chosen_span = [len(conv_history), len(full_chosen_text) - 1] - offset = len(full_chosen_text) - chosen_spans.append(chosen_span) + def __init__(self, config: ConfigType): + super().__init__(config) + self._source_shema: LanguageModelSourceConfig = self._config.dataset.source_shema - # compute rejected span - full_rejected_text = self._tokenizer.tokenizer.bos_token + conv_history + rejected_text - rejected_span = [ - offset + len(self._tokenizer.tokenizer.bos_token + conv_history), - offset + len(full_rejected_text) - 1, - ] - rejected_spans.append(rejected_span) + def _save_shard(self, args: tuple[int, datasets.Dataset]) -> MemmapDatasetConfig: + shard_index, shard_dataset = args + file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" - # pack texts - packed_text = full_chosen_text + full_rejected_text - - assert ( - packed_text[chosen_span[0] : chosen_span[1] + 1] == chosen_text + self._tokenizer.tokenizer.eos_token - ), f"{packed_text[chosen_span[0]: chosen_span[1] + 1]} does not match {chosen_text}" - - assert ( - packed_text[rejected_span[0] : rejected_span[1] + 1] == rejected_text - ), f"{packed_text[rejected_span[0]: rejected_span[1] + 1]} does not match {rejected_text}" - packed_texts.append(packed_text) - - # tokenize with spans - input_ids, chosen_token_spans, rejected_token_spans = map( - list, - zip( - *[ - ( - np.array(input_ids, dtype=self._data_type.numpy), - np.array(token_spans[0], dtype=np.int32), - np.array( - [token_spans[1][0], token_spans[1][1] + 1], dtype=np.int32 - ), # adding 1 to end for eos token - ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, [chosen_span, rejected_span]) - for text, chosen_span, rejected_span in zip(packed_texts, chosen_spans, rejected_spans) - ] - ] - ), + MemmapDataset.write_dataset( + self._config.output_path / file_name, + tqdm.tqdm((sample["sample"] for sample in shard_dataset), desc=f"Saving shard {shard_index}", unit="docs"), + LanguageModelWriter, ) - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "chosen_token_spans": chosen_token_spans, - "rejected_token_spans": rejected_token_spans, - "num_tokens": num_tokens, - } - - def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: - shard_idx, shard_dataset = args - prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" - shard_output_path = self._config.output_path / prefix - - def _document_generator(): - # TODO: Yield `LanguageModelSample` - if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), - None, - None, - ) - elif ( - "chosen_token_spans" in shard_dataset.column_names - and "rejected_token_spans" in shard_dataset.column_names - and self._config.dataset.chosen_text is not None - and self._config.dataset.rejected_text is not None - ): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - None, - torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), - torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard 
{shard_idx}", unit="docs"): - yield ( - torch.tensor(item["input_ids"], dtype=self._data_type.torch), - None, - None, - None, - ) - - GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) - - return GPTMemmapDatasetConfig.from_dict( - { - "type": "memmap", - "path": prefix, - "num_documents": len(shard_dataset), # Use the length of the shard dataset directly - "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), - } - ) + return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}) def _load_dataset(self) -> datasets.Dataset: dataset = datasets.load_dataset( @@ -270,7 +143,11 @@ def run(self) -> None: # Prepare output directory self._config.output_path.mkdir(parents=True, exist_ok=True) - if pathlib.Path(self._config.dataset.path).is_dir(): + downloaded = pathlib.Path(self._config.dataset.path).is_dir() + if self._config.distributed.world_size > 1: + torch.distributed.barrier() + + if downloaded: # Dataset is already downloaded, load from disk dataset = self._load_dataset() else: @@ -296,54 +173,24 @@ def run(self) -> None: index=self._config.distributed.rank, ) - # Set data column and loss masking spans column based on source schema - if isinstance(self._config.dataset.source_schema, TextColumnConfig): - self._text_column = self._config.dataset.source_schema.input_column - self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column - else: - raise ValueError( - f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." - ) - - if self._text_column not in dataset.column_names: - raise ValueError(f"Dataset does not have field '{self._text_column}'.") - - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( - self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None - ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") - if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): - raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") - - # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans - elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: - if self._config.dataset.chosen_text not in dataset.column_names: - raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") - if self._config.dataset.rejected_text not in dataset.column_names: - raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.") - tokenize_fn = self._tokenize_preference_batch_with_spans - else: - tokenize_fn = self._tokenize_batch + for column_name in self._source_shema.columns: + if column_name not in dataset.column_names: + raise ValueError(f"Dataset does not have field '{column_name}'.") # Tokenize the dataset in parallel - tokenized_dataset = dataset.map( - tokenize_fn, + prepared_dataset = dataset.map( + self._prepare_batch, batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", ) - # Calculate total number of tokens - total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting 
tokens", unit="tokens")) - # Split dataset into shards based on number of tokens - num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) + num_shards = math.ceil( + sum(len(sample) for sample in prepared_dataset["samples"]) / self._config.tokens_per_shard + ) shards = [ - (i, tokenized_dataset.shard(num_shards=num_shards, index=i)) + (i, prepared_dataset.shard(num_shards=num_shards, index=i)) for i in tqdm.tqdm(range(num_shards), desc="Creating shards") ] @@ -353,7 +200,67 @@ def run(self) -> None: self.generate_config_yaml_for_sharded_dst(dataset_configs) - def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDatasetConfig]) -> None: + def _prepare_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[LanguageModelSample]]: + # Gather values by sample using zip* + sample_column_values = zip(*(batch[column_name] for column_name in self._source_shema.columns)) + # Convert to dicts using column names. + sample_dicts = ( + {column_name: column_value for column_name, column_value in zip(self._source_shema.columns, sample_data)} + for sample_data in sample_column_values + ) + # Prepare each sample, wrap in dict for the `Dataset` interface + return {"samples": [self._prepare_sample(sample_dict) for sample_dict in sample_dicts]} + + def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: + text = sample[self._source_shema.text_column] + all_spans = [] + if self._source_shema.has_loss_masking_span: + # TODO: ====== What is the input format? ====== + # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. + loss_masking_spans = _sort_spans( + (begin, last + 1) + for begin, last in np.array(sample[self._source_shema.loss_masking_spans_column], dtype=np.int32) + .reshape(-1, 2) + .tolist() + ) + all_spans.extend(loss_masking_spans) + + if self._source_shema.has_preference_spans: + # TODO: ===== Was `self._config.dataset.field` (bug?) 
====== + full_chosen_text = ( + text + sample[self._source_shema.chosen_spans_column] + self._tokenizer.tokenizer.eos_token + ) + full_rejected_text = ( + self._tokenizer.tokenizer.bos_token + text + sample[self._source_shema.rejected_spans_column] + ) + # compute chosen span + chosen_spans = [[len(text), len(full_chosen_text)]] + + # compute rejected span + rejected_span = [ + [ + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), + ] + ] + # pack texts + text = full_chosen_text + full_rejected_text + all_spans.extend(chosen_spans + rejected_span) + + tokens = torch.tensor( + self._tokenizer.tokenize_with_spans(text, True, True, spans=_sort_spans(all_spans)), + dtype=self._data_type.torch, + ) + sample_size = len(tokens) + + return LanguageModelSample( + TokenSample(tokens, [sample_size]), + RangeSample(loss_masking_spans, sample_size) if self._source_shema.has_loss_masking_span else None, + RangeSample(chosen_spans, sample_size) if self._source_shema.has_preference_spans else None, + RangeSample(rejected_span, sample_size) if self._source_shema.has_preference_spans else None, + ) + + def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[MemmapDatasetConfig]) -> None: # Gather dataset_dicts from all ranks to rank 0 if self._config.distributed.world_size > 1: if self._config.distributed.rank == 0: @@ -397,7 +304,7 @@ def _save_dataset_config( @classmethod def _blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]] + cls, dataset_configs: list[MemmapDatasetConfig[_sample_type]] ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] @@ -412,10 +319,11 @@ def _blend_dataset_configs( @classmethod def _split_and_blend_dataset_configs( cls, - dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]], + dataset_configs: list[MemmapDatasetConfig[_sample_type]], splits: dict[str, int | float], output_path: pathlib.Path, ) -> dict[str, SampledDatasetConfig[_sample_type]]: + # TODO: ====== Missing `num_tokens`, `num_documents`. ====== split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) @@ -483,6 +391,10 @@ def _split_and_blend_dataset_configs( return dataset_splits +def _sort_spans(spans: typing.Iterable[tuple[int, int]]) -> list[tuple[int, int]]: + return sorted(spans, key=lambda span: span[0]) + + def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: left = cumsum.searchsorted(value, side="right") if left == len(cumsum): diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 031002101..f122100f9 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -1,6 +1,11 @@ import abc +import io +import pathlib import typing +from fast_llm.config import Config, Configurable, Field, config_class +from fast_llm.utils import Assert + if typing.TYPE_CHECKING: import torch @@ -40,3 +45,150 @@ def crop(self, begin: int, end: int) -> typing.Self: def to_device_(self, device: "torch.device | str"): pass + + +@config_class(registry=True) +class MemmapReaderBaseConfig(Config): + """ + Configuration for a memmap reader or reader-like object. + Note: `MemmapDataset` requires a `MemmapIndexedDatasetReader`. 
+    Other readers need to be nested within a `MemmapIndexedDatasetReader`.
+    Note: Reader configs are not typical configs, and do not need to be located in a separate `config.py` file.
+    """
+
+    _abstract = True
+
+    @classmethod
+    def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self:
+        if cls is MemmapReaderBaseConfig and cls.get_subclass(default.get("type")) is None:
+            # Default subclass, necessary for loading configs where some components could be absent.
+            return NullReaderConfig._from_dict(default, strict)
+        return super()._from_dict(default, strict=strict)
+
+    def get_reader(self, buffer: memoryview) -> "MemmapReader|None":
+        raise NotImplementedError()
+
+    @property
+    def expected_buffer_size(self) -> int:
+        """
+        The expected buffer size in bytes. Used for self-validation.
+        """
+        raise NotImplementedError()
+
+
+@config_class(dynamic_type={MemmapReaderBaseConfig: "none"})
+class NullReaderConfig(MemmapReaderBaseConfig):
+    """
+    Configuration for a dynamically disabled reader.
+    """
+
+    _abstract = False
+
+    def get_reader(self, buffer: memoryview) -> None:
+        return None
+
+    @property
+    def expected_buffer_size(self) -> int:
+        return 0
+
+
+@config_class(registry=True)
+class MemmapReaderConfig(MemmapReaderBaseConfig):
+    """
+    Configuration for a standard memmap reader.
+    """
+
+    begin: int = Field()
+    end: int = Field()
+
+    @property
+    def reader_class(self) -> "type[MemmapReader]":
+        raise NotImplementedError()
+
+    def get_reader(self, buffer: memoryview) -> "MemmapReader":
+        return self.reader_class(self, buffer[self.begin : self.end])
+
+    @property
+    def writer_class(self) -> "type[MemmapWriter]":
+        raise NotImplementedError()
+
+    def get_writer(self, stream: io.BufferedWriter) -> "MemmapWriter":
+        return self.writer_class(stream)
+
+    def _validate(self):
+        super()._validate()
+        Assert.eq(self.end - self.begin, self.expected_buffer_size)
+
+
+@config_class()
+class MemmapIndexDatasetReaderConfig(MemmapReaderConfig):
+    """
+    Configuration for a standard memmap reader matching the indexed dataset interface, i.e.,
+    consisting of a list of documents of known lengths.
+ """ + + @property + def reader_class(self) -> "type[MemmapIndexedDatasetReader]": + raise NotImplementedError() + + def get_reader( + self, + buffer: memoryview, + ) -> "MemmapIndexedDatasetReader": + return self.reader_class(self, buffer[self.begin : self.end]) + + +class MemmapReader[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config) + self._buffer = buffer[self._config.begin : self._config.end] + + @abc.abstractmethod + def get_document(self, index: int, begin: int, end: int) -> Sample: + pass + + +class MemmapIndexedDatasetReader[ConfigType: MemmapIndexDatasetReaderConfig](MemmapReader[ConfigType]): + @abc.abstractmethod + def get_document_sizes(self) -> "torch.Tensor": + pass + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + pass + + +class MemmapWriter: + def __init__(self, stream: io.BufferedWriter | pathlib.Path): + self._owns_stream = isinstance(stream, pathlib.Path) + if self._owns_stream: + stream = stream.open("wb") + self._stream = stream + + def __enter__(self): + self._begin = self._stream.tell() + return self + + def write(self, document: Sample): + assert hasattr(self, "_begin") and not hasattr(self, "_end") + + def __exit__(self, exc_type, exc_val, exc_tb): + self._end = self._stream.tell() + if self._owns_stream: + self._stream.close() + + def get_config(self, offset: int = 0) -> MemmapReaderConfig: + assert hasattr(self, "_end") + return self._get_config(self._begin + offset, self._end + offset) + + @abc.abstractmethod + def _get_config(self, begin: int, end: int): + pass + + @classmethod + def write_dataset(cls, stream: io.BufferedWriter, documents: typing.Iterable[Sample]) -> MemmapReaderConfig: + with cls(stream) as writer: + for document in documents: + writer.write(document) + return writer.get_config() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index f30188553..3d6964b30 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -1,8 +1,23 @@ +import io +import pathlib +import tempfile import typing -from fast_llm.data.sample.abstract import Batch, Sample -from fast_llm.data.sample.range import RangeBatch, RangeSample -from fast_llm.data.sample.token import TokenBatch, TokenSample +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.sample.abstract import ( + Batch, + MemmapIndexDatasetReaderConfig, + MemmapIndexedDatasetReader, + MemmapReaderBaseConfig, + MemmapWriter, + NullReaderConfig, + Sample, +) +from fast_llm.data.sample.range import RangeBatch, RangeSample, RangeWriter +from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter +from fast_llm.utils import Assert class LanguageModelSample(Sample): @@ -105,3 +120,168 @@ def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.I def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) -> T | None: return None if sample_or_batch is None else sample_or_batch.crop(begin, end) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "language_model"}) +class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): + _abstract = False + tokens: TokenReaderConfig = Field() + # Using dynamic type for optional readers for enabling/disabling + loss_masking_spans: MemmapReaderBaseConfig = Field() + chosen_spans: MemmapReaderBaseConfig = Field() + rejected_spans: MemmapReaderBaseConfig = 
Field()
+
+    @property
+    def reader_class(self) -> "type[LanguageModelReader]":
+        return LanguageModelReader
+
+    @property
+    def writer_class(self) -> "type[LanguageModelWriter]":
+        return LanguageModelWriter
+
+    @property
+    def expected_buffer_size(self) -> int:
+        return (
+            self.tokens.expected_buffer_size
+            + self.loss_masking_spans.expected_buffer_size
+            + self.chosen_spans.expected_buffer_size
+            + self.rejected_spans.expected_buffer_size
+        )
+
+
+class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]):
+    def __init__(self, config: ConfigType, buffer: memoryview):
+        super().__init__(config, buffer)
+        # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global.
+        self._tokens = self._config.tokens.get_reader(buffer)
+        self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer)
+        self._chosen_spans = self._config.chosen_spans.get_reader(buffer)
+        self._rejected_spans = self._config.rejected_spans.get_reader(buffer)
+
+    def get_document(self, index: int, begin: int, end: int) -> Sample:
+        return LanguageModelSample(
+            self._tokens.get_document(index, begin, end),
+            None if self._loss_masking_spans is None else self._loss_masking_spans.get_document(index, begin, end),
+            None if self._chosen_spans is None else self._chosen_spans.get_document(index, begin, end),
+            None if self._rejected_spans is None else self._rejected_spans.get_document(index, begin, end),
+        )
+
+    def get_document_sizes(self) -> torch.Tensor:
+        return self._tokens.get_document_sizes()
+
+    def get_document_size(self, index: int) -> int:
+        return self._tokens.get_document_size(index)
+
+
+class LanguageModelWriter(MemmapWriter):
+    _has_loss_masking_spans: bool | None = None
+    _has_preference_spans: bool | None = None
+
+    def __enter__(self):
+        super().__enter__()
+        self._size_cumsum = [0]
+        self._data_type = None
+
+        self._directory = tempfile.TemporaryDirectory()
+        self._path = pathlib.Path(self._directory.name)
+        # We write intermediate results in separate files so we don't need to iterate over the dataset multiple times.
+        self._token_writer = TokenWriter(self._path.joinpath("tokens")).__enter__()
+        self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__()
+        self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__()
+        self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__()
+        return self
+
+    def write(self, document: LanguageModelSample):
+        # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ======
+        super().write(document)
+        # Write tokens.
+        self._token_writer.write(document.tokens)
+
+        # Ensure either all samples have loss masking spans or none of them do.
+        if self._has_loss_masking_spans is None:
+            self._has_loss_masking_spans = document.loss_masking_spans is not None
+        else:
+            Assert.eq(self._has_loss_masking_spans, document.loss_masking_spans is not None)
+
+        # Write loss masking spans.
+        if self._has_loss_masking_spans:
+            self._loss_masking_span_writer.write(document.loss_masking_spans)
+
+        # All samples must either have both chosen and rejected spans, or neither.
+        if self._has_preference_spans is None:
+            self._has_preference_spans = document.chosen_spans is not None
+        else:
+            Assert.eq(self._has_preference_spans, document.chosen_spans is not None)
+        Assert.eq(self._has_preference_spans, document.rejected_spans is not None)
+
+        # Write preference spans.
+ if self._has_preference_spans: + self._chosen_spans_writer.write(document.chosen_spans) + self._rejected_spans_writer.write(document.rejected_spans) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._token_writer.__exit__(exc_type, exc_val, exc_tb) + self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) + self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) + self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) + + # A dummy config so we can verify the begin and end offsets. + config = self._get_config(self._begin, None) + _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) + + if self._has_loss_masking_spans: + _copy_chunked( + self._path.joinpath("loss_masking_spans"), + self._stream, + config.loss_masking_spans.begin, + config.loss_masking_spans.end, + ) + if self._has_preference_spans: + _copy_chunked( + self._path.joinpath("chosen_spans"), self._stream, config.chosen_spans.begin, config.chosen_spans.end + ) + _copy_chunked( + self._path.joinpath("rejected_spans"), + self._stream, + config.rejected_spans.begin, + config.rejected_spans.end, + ) + + self._directory.cleanup() + super().__exit__(exc_type, exc_val, exc_tb) + + def _get_config(self, begin: int, end: int | None): + tokens = self._token_writer.get_config(begin) + offset = tokens.end + if self._has_loss_masking_spans: + loss_masking_spans = self._loss_masking_span_writer.get_config(offset) + offset = loss_masking_spans.end + else: + loss_masking_spans = NullReaderConfig() + if self._has_preference_spans: + chosen_spans = self._chosen_spans_writer.get_config(offset) + offset = chosen_spans.end + rejected_spans = self._rejected_spans_writer.get_config(offset) + offset = rejected_spans.end + else: + chosen_spans = NullReaderConfig() + rejected_spans = NullReaderConfig() + + if end is None: + end = offset + + return LanguageModelReaderConfig( + begin=begin, + end=end, + tokens=tokens, + loss_masking_spans=loss_masking_spans, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + ) + + +def _copy_chunked(path: pathlib.Path, stream: io.BufferedWriter, expected_begin: int, expected_end: int): + # Copy temporary file content in chunks of 100 MB. 
+    Assert.eq(stream.tell(), expected_begin)
+    with path.open("rb") as input_stream:
+        while data := input_stream.read(100000000):
+            stream.write(data)
+    Assert.eq(stream.tell(), expected_end)
diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py
index d121a38b6..88dd1352d 100644
--- a/fast_llm/data/sample/range.py
+++ b/fast_llm/data/sample/range.py
@@ -1,6 +1,17 @@
 import typing
 
-from fast_llm.data.sample.abstract import Batch, Sample
+import numpy as np
+import torch
+
+from fast_llm.config import Field, config_class
+from fast_llm.data.sample.abstract import (
+    Batch,
+    MemmapReader,
+    MemmapReaderBaseConfig,
+    MemmapReaderConfig,
+    MemmapWriter,
+    Sample,
+)
 from fast_llm.utils import get_unique
@@ -47,3 +58,71 @@ def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self:
 
     def to_samples(self) -> list[RangeSample]:
         return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges]
+
+
+@config_class(dynamic_type={MemmapReaderBaseConfig: "range"})
+class RangeReaderConfig(MemmapReaderConfig):
+    _abstract = False
+    num_documents: int = Field()
+    num_ranges: int = Field()
+
+    @property
+    def reader_class(self) -> "type[RangeReader]":
+        return RangeReader
+
+    @property
+    def writer_class(self) -> "type[RangeWriter]":
+        return RangeWriter
+
+    @property
+    def expected_buffer_size(self) -> int:
+        return self.num_ranges * 2 * torch.uint32.itemsize + (self.num_documents + 1) * torch.uint32.itemsize
+
+
+class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]):
+    def __init__(self, config: ConfigType, buffer: memoryview):
+        super().__init__(config, buffer)
+        self._ranges = torch.frombuffer(
+            self._buffer,
+            dtype=torch.uint32,
+            count=self._config.num_ranges * 2,
+        ).reshape(-1, 2)
+        self._count_cumsums = torch.frombuffer(
+            self._buffer,
+            dtype=torch.uint32,
+            count=self._config.num_documents + 1,
+            offset=self._ranges.nbytes,
+        )
+
+    def get_document(self, index: int, begin: int, end: int) -> RangeSample:
+        sample_size = end - begin
+        cropped_ranges = (
+            (max(begin_ - begin, 0), min(end_ - begin, sample_size))
+            for begin_, end_ in self._ranges[self._count_cumsums[index] : self._count_cumsums[index + 1]].tolist()
+        )
+        return RangeSample([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size)
+
+
+class RangeWriter(MemmapWriter):
+    def __enter__(self):
+        super().__enter__()
+        self._count_cumsum = [0]
+        return self
+
+    def write(self, document: RangeSample):
+        # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ======
+        super().write(document)
+        self._stream.write(np.array(document.ranges, dtype=np.uint32).tobytes(order="C"))
+        self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges))
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stream.write(np.array(self._count_cumsum, dtype=np.uint32).tobytes(order="C"))
+        super().__exit__(exc_type, exc_val, exc_tb)
+
+    def _get_config(self, begin: int, end: int):
+        return RangeReaderConfig(
+            begin=begin,
+            end=end,
+            num_documents=len(self._count_cumsum) - 1,
+            num_ranges=self._count_cumsum[-1],
+        )
diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py
index 62d1c0e67..98ee9a2a1 100644
--- a/fast_llm/data/sample/token.py
+++ b/fast_llm/data/sample/token.py
@@ -1,8 +1,18 @@
 import typing
 
+import numpy as np
 import torch
 
-from fast_llm.data.sample.abstract import Batch, Sample
+from fast_llm.config import Field, config_class
+from fast_llm.data.sample.abstract import (
+    Batch,
+ MemmapIndexedDatasetReader, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -73,3 +83,77 @@ def crop(self, begin: int, end: int) -> typing.Self: def to_device_(self, device: "torch.device | str"): # Also standardize the dtype while we're here. self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "token"}) +class TokenReaderConfig(MemmapReaderConfig): + _abstract = False + num_documents: int = Field() + num_tokens: int = Field() + data_type: DataType = Field() + + @property + def reader_class(self) -> "type[TokenReader]": + return TokenReader + + @property + def writer_class(self) -> "type[TokenWriter]": + return TokenWriter + + @property + def expected_buffer_size(self) -> int: + return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.uint64.itemsize + + +class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + self._tokens = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_tokens, + ) + self._size_cumsums = torch.frombuffer( + self._buffer, dtype=torch.uint64, count=self._config.num_documents + 1, offset=self._tokens.nbytes + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + begin_ = self._size_cumsums[index].item() + return TokenSample(torch.from_numpy(self._tokens[begin_ + begin : begin_ + end]), [end - begin]) + + def get_document_sizes(self) -> torch.Tensor: + return self._size_cumsums[1:] - self._size_cumsums[:-1] + + def get_document_size(self, index: int) -> int: + return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item() + + +class TokenWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._size_cumsum = [0] + self._data_type = None + return self + + def write(self, document: TokenSample): + # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== + super().write(document) + if self._data_type is None: + self._data_type = document.tokens.dtype + else: + Assert.eq(self._data_type, document.tokens.dtype) + self._stream.write(document.tokens.numpy().tobytes()) + self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._stream.write(np.array(self._size_cumsum, dtype=np.uint64).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + def _get_config(self, begin: int, end: int): + return TokenReaderConfig( + begin=begin, + end=end, + num_documents=len(self._size_cumsum) - 1, + num_tokens=self._size_cumsum[-1], + data_type=DataType.from_torch(self._data_type), + ) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c74586207..71219a2bf 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -4,6 +4,7 @@ from fast_llm.data.config import TokenizerConfig from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.utils import Assert class Tokenizer: @@ -41,7 +42,7 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def tokenize(self, text: str, begin: bool = True, end: bool = True) -> list[int]: return ( ([self.bod_id] 
if begin else [])
             + self.tokenizer.encode(text, add_special_tokens=False)
@@ -49,36 +50,35 @@ def tokenize(self, text: str, begin=True, end=True) -> list[int]:
         )
 
     def tokenize_with_spans(
-        self, text: str, char_spans: list[tuple[int, int]]
+        self, text: str, begin: bool = True, end: bool = True, *, spans: list[tuple[int, int]]
     ) -> tuple[list[int], list[tuple[int, int]]]:
         """
         Perform span-aware tokenization and return the tokenized input_ids along with token spans.
         """
+        if not spans:
+            return self.tokenize(text, begin, end), []
+        input_ids, token_splits = self.tokenize_with_splits(
+            text, begin, end, text_splits=[split for splits in spans for split in splits]
+        )
+        return input_ids, [(begin, end) for begin, end in zip(token_splits[::2], token_splits[1::2], strict=True)]
+
+    def tokenize_with_splits(
+        self, text: str, begin: bool = True, end: bool = True, *, text_splits: list[int]
+    ) -> tuple[list[int], list[int]]:
+        Assert.eq(sorted(text_splits), text_splits)
         input_ids = []
-        token_spans = []
-        char_pos = 0
-        beginning_of_text = True
+        text_splits = [0, *text_splits, len(text)]
+        token_splits = []
+
+        for split_begin, split_end in zip(text_splits[:-1], text_splits[1:]):
+            input_ids.extend(
+                self.tokenize(
+                    text[split_begin:split_end], begin=begin and split_begin == 0, end=end and split_end == len(text)
+                )
+            )
+            token_splits.append(len(input_ids))
 
-        for start, end in char_spans:
-            if char_pos < start:
-                curr_text = text[char_pos:start]
-                tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False)
-                beginning_of_text = False
-                input_ids.extend(tokenized_text)
-            curr_text = text[start : end + 1]
-            if end >= len(text) - 1:
-                tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True)
-            else:
-                tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False)
-                beginning_of_text = False
-            token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1))
-            input_ids.extend(tokenized_text)
-            char_pos = end + 1
-        if char_pos < len(text):
-            curr_text = text[char_pos:]
-            tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True)
-            input_ids.extend(tokenized_text)
-        return input_ids, token_spans
+        return input_ids, token_splits[:-1]
 
     def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str:
         return self.tokenizer.decode(token_ids)
diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py
index 0099cb50b..49eceee0b 100644
--- a/tests/data/test_blending.py
+++ b/tests/data/test_blending.py
@@ -13,7 +13,7 @@
     get_test_data_and_compare_samples,
 )
 from tests.utils.dataset import get_test_dataset
-from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX
+from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH
 
 _DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset"
 
@@ -118,7 +118,7 @@ def test_gpt_blended():
         {
             "type": "blended",
             "datasets": [
-                {"type": "memmap", "path": DATASET_PREFIX},
+                {"type": "memmap", "path": DATASET_PATH},
                 {"type": "memmap", "path": _DATASET_PREFIX_MIX_1},
             ],
             "weights": [0.75, 0.25],
@@ -137,7 +137,7 @@ def test_gpt_blended_data():
             "training": {
                 "type": "blended",
                 "datasets": [
-                    {"type": "memmap", "path": DATASET_PREFIX},
+                    {"type": "memmap", "path": DATASET_PATH},
                     {"type": "memmap", "path": _DATASET_PREFIX_MIX_1},
                 ],
                 "weights": [0.75, 0.25],
@@ -157,7 +157,7 @@ def test_gpt_blended_mixed():
         {
             "type": "blended",
             "datasets": [
-                {"type": "memmap", "path": DATASET_PREFIX},
+                {"type": "memmap", "path": DATASET_PATH},
{"type": "random"}, ], "weights": [0.6, 0.4], @@ -174,7 +174,7 @@ def test_gpt_blended_mixed_data(): "datasets": { "training": { "type": "blended", - "datasets": [{"type": "memmap", "path": DATASET_PREFIX}, {"type": "random"}], + "datasets": [{"type": "memmap", "path": DATASET_PATH}, {"type": "random"}], "weights": [0.6, 0.4], } } diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 5335e01c0..7b009bbf6 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -9,7 +9,7 @@ ) from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH GPT_CONCATENATED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], @@ -27,7 +27,7 @@ def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. get_test_dataset() dataset = get_dataset_config( - {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, + {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], ).build() compare_indexed_dataset( @@ -47,7 +47,7 @@ def test_gpt_concatenate_data(): "datasets": { "training": { "type": "concatenated", - "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)], + "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)], } } }, diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index c149e1395..af91df1e2 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -2,11 +2,11 @@ from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH def test_dataset_from_file(): get_test_dataset() - dataset_config = {"type": "file", "path": str(DATASET_PREFIX.parent.joinpath("fast_llm_config.yaml"))} + dataset_config = {"type": "file", "path": str(DATASET_PATH.parent.joinpath("fast_llm_config.yaml"))} dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 438c5e7e3..b9dc7fe32 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -6,7 +6,7 @@ get_test_data_and_compare_samples, ) from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX, TOKENIZER_PATH +from tests.utils.global_variables import DATASET_PATH, TOKENIZER_PATH GPT_FIM_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], @@ -32,7 +32,7 @@ def test_gpt_fim(): sampled = get_dataset_config( { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", @@ -52,7 +52,7 @@ def test_gpt_fim_data(): "datasets": { "training": { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "tokenizer": {"path": 
TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index ca887f3c1..419b67903 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -2,10 +2,10 @@ import pytest -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig +from fast_llm.data.dataset.config import MemmapDatasetConfig from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE +from tests.utils.global_variables import DATASET_PATH, DATASET_SAMPLING_CACHE, DATASET_WITH_SPANS_PATH MEMMAP_DATASET_LENGTH = 6153 MEMMAP_DATASET_TOKENS = 508327 @@ -21,7 +21,7 @@ def test_gpt_memmap(cache_directory): # Make sure the memmap dataset works and check for unintended changes in behavior. get_test_dataset() - dataset = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build() + dataset = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build() compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) @@ -32,17 +32,15 @@ def test_gpt_memmap(cache_directory): 15: [], } -_DATASET_PREFIX_SPANS = DATASET_CACHE / "with_spans" / "dataset" - def test_gpt_data_with_spans(): - get_test_dataset(prefix=_DATASET_PREFIX_SPANS, max_spans=5) + get_test_dataset(DATASET_WITH_SPANS_PATH, max_spans=5) dataset = get_dataset_config( { "type": "memmap", - "path": _DATASET_PREFIX_SPANS, + "path": DATASET_WITH_SPANS_PATH, }, - GPTMemmapDatasetConfig, + MemmapDatasetConfig, ).build() compare_indexed_dataset( dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_SPANS diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 601abcf99..1608bb48c 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -8,10 +8,12 @@ from fast_llm.data.dataset.config import IndexedDatasetConfig from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert from tests.data.common import MockGPTMemmapDatasetConfig # Noqa @@ -31,15 +33,17 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): documents = [ - (torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)), None, None, None) + LanguageModelSample( + TokenSample(torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype))) + ) for _ in range(100) ] with tempfile.TemporaryDirectory() as temp_dir: - prefix = pathlib.Path(temp_dir) - GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) - dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, (tokens, _, _, 
_) in enumerate(documents): - Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens.to(torch.int64)) + path = pathlib.Path(temp_dir) / "dataset" + MemmapDataset.write_dataset(path, documents, LanguageModelWriter) + dataset = MemmapDataset("dataset", path) + for i, document in enumerate(documents): + Assert.all_equal(dataset.get_document(i).tokens.tokens, document.tokens.tokens.to(torch.int64)) def _generate_valid_span(max_seq_length): @@ -49,26 +53,26 @@ def _generate_valid_span(max_seq_length): @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_preference_dataset(dtype): documents = [ - ( - torch.from_numpy(np.random.randint(1000, size=100).astype(dtype)), + LanguageModelSample( + TokenSample(torch.from_numpy(np.random.randint(1000, size=100).astype(dtype))), None, - _generate_valid_span(100), - _generate_valid_span(100), + RangeSample(_generate_valid_span(100), 100), + RangeSample(_generate_valid_span(100), 100), ) for _ in range(50) ] with tempfile.TemporaryDirectory() as temp_dir: - prefix = pathlib.Path(temp_dir) - GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) - dataset = GPTMemmapDataset(name="foo", prefix=prefix) + path = pathlib.Path(temp_dir) / "dataset" + MemmapDataset.write_dataset(path, documents, LanguageModelWriter) + dataset = MemmapDataset("dataset", path) parameters = GPTSamplingParameters( num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True ) - for i, (token_ids, _, (chosen_begin, chosen_end), (rejected_begin, rejected_end)) in enumerate(documents): - document = dataset.get_document(i, parameters=parameters) - Assert.all_equal(document.tokens.tokens, token_ids.to(torch.int64)) - Assert.eq(document.chosen_spans.ranges, [(chosen_begin, chosen_end + 1)]) - Assert.eq(document.rejected_spans.ranges, [(rejected_begin, rejected_end + 1)]) + for i, document in enumerate(documents): + dataset_document = dataset.get_document(i, parameters=parameters) + Assert.all_equal(dataset_document.tokens.tokens, document.tokens.tokens.to(torch.int64)) + Assert.eq(dataset_document.chosen_spans.ranges, document.chosen_spans.ranges) + Assert.eq(dataset_document.rejected_spans.ranges, document.rejected_spans.ranges) def test_load_metadata_from_hub(): diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 58f4d3dab..c171d15dd 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -2,8 +2,8 @@ import pytest import torch -from fast_llm.data.dataset.config import ShufflingType -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters +from fast_llm.data.dataset.config import MemmapDatasetConfig, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -15,7 +15,7 @@ validate_indexed_dataset_sampling, ) from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa @@ -40,7 +40,7 @@ def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. 
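
As a companion to the preparator tests above, a condensed sketch of the write/read round-trip with the new memmap format; the path and sample values are invented, while the API names are the ones introduced in this patch:

import pathlib

import torch

from fast_llm.data.dataset.memmap import MemmapDataset
from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter
from fast_llm.data.sample.token import TokenSample

path = pathlib.Path("/tmp/example.fast_llm_dataset")  # hypothetical location
documents = [LanguageModelSample(TokenSample(torch.arange(10)))]
MemmapDataset.write_dataset(path, documents, LanguageModelWriter)
dataset = MemmapDataset("example", path)
assert dataset.get_document(0).tokens.tokens.tolist() == list(range(10))
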
get_test_dataset() - sampled = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build_and_sample( + sampled = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build_and_sample( get_sampling_data(8, sequence_length=5) ) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) @@ -53,7 +53,7 @@ def test_gpt_sampled_data(): "datasets": { "training": { "type": "memmap", - "path": DATASET_PREFIX, + "path": DATASET_PATH, } } }, diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 3c6ae10d4..3a6b999cd 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -9,7 +9,7 @@ ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH GPT_SLICE_TRAINING_SAMPLES = [ [80, 268, 79, 260, 207, 3086], @@ -34,7 +34,7 @@ def test_gpt_slice(): get_test_dataset() # samples[9:18] dataset = get_dataset_config( - {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, + {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003}, DatasetSliceConfig[LanguageModelSample], ).build() compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) @@ -48,19 +48,19 @@ def test_gpt_slice_data(): "datasets": { "training": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0, "end": 0.0015, }, "validation": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003, }, "test": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.003, "end": 1, }, diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 7447e395a..42a7c1f0d 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -6,16 +6,16 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import SampledDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingData -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.config import MemmapDatasetConfig, SampledDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PREFIX +from tests.utils.global_variables import MODEL_DATASET_PATH from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -69,7 +69,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare="megatron", config_args=[ "model.distributed.compute_dtype=fp32", - 
f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PREFIX}}}', + f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PATH}}}', "data.sampling.seed=1234", "model.base_model.use_megatron_initialization=True", ], @@ -82,25 +82,23 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co @config_class(dynamic_type={SampledDatasetConfig: "megatron"}) -class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig): +class MegatronDatasetConfig[SampleType: LanguageModelSample](MemmapDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: str = Field( desc="Dataset path (prefix).", hint=FieldHint.core, ) - def build(self) -> "GPTMemmapDataset": - return GPTMegatronMemmapDataset( - str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens - ) + def build(self) -> "LegacyMemmapDataset[SampleType]": + return MegatronMemmapDataset(str(self.path).replace("/", "__"), self.path) -class GPTMegatronMemmapDataset(GPTMemmapDataset): - def sample(self, sampling: GPTSamplingData) -> "MegatronGPTSampledIndexedDataset": - return MegatronGPTSampledIndexedDataset(self, sampling) +class MegatronMemmapDataset(LegacyMemmapDataset): + def sample(self, sampling: GPTSamplingData) -> "MegatronSampledIndexedDataset": + return MegatronSampledIndexedDataset(self, sampling) -class MegatronGPTSampledIndexedDataset(SampledDataset): +class MegatronSampledIndexedDataset(SampledDataset): """ A GPT sampled dataset that exactly matches Megatron-LM, for testing purposes. Minimalistic implementation, implements only the required features. @@ -108,7 +106,7 @@ class MegatronGPTSampledIndexedDataset(SampledDataset): def __init__( self, - indexed_dataset: GPTMegatronMemmapDataset, + indexed_dataset: MegatronMemmapDataset, sampling: GPTSamplingData, ): assert isinstance(sampling, GPTSamplingData) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 428dec56b..baff00b80 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -5,10 +5,13 @@ import torch import yaml -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter +from fast_llm.data.sample.range import RangeSample +from fast_llm.data.sample.token import TokenSample from tests.utils.global_variables import ( - DATASET_PREFIX, - MODEL_DATASET_PREFIX, + DATASET_PATH, + MODEL_DATASET_PATH, MODEL_TEST_VOCAB_SIZE, TEST_CHARACTERS, TEST_DATASET_TOKENS, @@ -35,7 +38,7 @@ def get_random_spans(num_samples: int, max_spans: int, lengths: np.ndarray | int def get_test_dataset( - prefix: pathlib.Path = DATASET_PREFIX, + path: pathlib.Path = DATASET_PATH, seed: int = 1234, num_tokens: int = TEST_DATASET_TOKENS, characters: str = TEST_CHARACTERS, @@ -43,48 +46,35 @@ def get_test_dataset( max_spans: int = 0, ): download_santacoder_tokenizer() + config_path = path.parent.joinpath("fast_llm_config.yaml") - if not ( - prefix.with_suffix(".idx").is_file() - and prefix.with_suffix(".bin").is_file() - and prefix.parent.joinpath("fast_llm_config.yaml").is_file() - ): + if not (path.is_file() and config_path.is_file()): import transformers texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) samples = [ - ( - torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size), - None, - None, - None, + 
LanguageModelSample( + TokenSample( + torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) + ), ) for document in texts ] if max_spans > 0: spans = get_random_spans( - len(samples), max_spans, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), seed + len(samples), max_spans, np.array([[max(len(sample), 1)] for sample in samples]), seed ) - samples = [ - ( - tokens, - torch.tensor(sample_spans, dtype=torch.int32).reshape(-1, 2), - None, - None, - ) - for (tokens, _, _, _), sample_spans in zip(samples, spans, strict=True) - ] + for sample, sample_spans in zip(samples, spans, strict=True): + sample.loss_masking_spans = RangeSample(sample_spans, len(sample)) - GPTMemmapDataset.write_dataset(prefix, samples) - yaml.safe_dump( - {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") - ) + MemmapDataset.write_dataset(path, samples, LanguageModelWriter) + yaml.safe_dump({"type": "memmap", "path": path.name}, config_path.open("w")) def get_model_test_dataset( - prefix: pathlib.Path = MODEL_DATASET_PREFIX, + path: pathlib.Path = MODEL_DATASET_PATH, vocab_size: int = MODEL_TEST_VOCAB_SIZE, ): - return get_test_dataset(prefix=prefix, vocab_size=vocab_size) + return get_test_dataset(path, vocab_size=vocab_size) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 42e588911..c62903a6c 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -37,12 +37,13 @@ def set_testing_global_variables(): TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common_dataset" +DATASET_PATH = DATASET_CACHE / "common_dataset.fast_llm_dataset" +DATASET_WITH_SPANS_PATH = DATASET_CACHE / "dataset_with_spans.fast_llm_dataset" DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" TEST_VOCAB_SIZE = 8192 # Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" TEST_DATASET_TOKENS = 1000000 -MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" +MODEL_DATASET_PATH = DATASET_CACHE / "model_dataset.fast_llm_dataset" MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c02521d7b..adcf84b18 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -22,7 +22,7 @@ Qwen2CheckpointFormat, ) from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE +from tests.utils.global_variables import MODEL_DATASET_PATH, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration EvaluatorsConfig, @@ -234,18 +234,18 @@ def _update_and_add_testing_config( "data": { "datasets": { "training": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "end": 0.969, }, "validation": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "begin": 0.969, "end": 0.999, }, "test": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "begin": 0.999, "end": 1, @@ -279,7 +279,7 @@ def 
_update_and_add_testing_config( "--tokenizer-type=NullTokenizer", # Megatron messes with the vocab size, so we have to subtract 1. f"--vocab-size={MODEL_TEST_VOCAB_SIZE - 1}", - f"--data-path={MODEL_DATASET_PREFIX}", + f"--data-path={MODEL_DATASET_PATH}", "--split=1,0,0", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py deleted file mode 100644 index bbfa4b21a..000000000 --- a/tools/concatenate_dataset.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -import logging -import pathlib - -from fast_llm.config import Field, config_class -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.engine.config_utils.runnable import RunnableConfig - -logger = logging.getLogger(__name__) - - -@config_class() -class ConcatenateDatasetConfig(RunnableConfig): - directory: pathlib.Path = Field() - output_name: str = Field(default="fast_llm_dataset.json") - # A lower bound on the number of tokens in a dataset. - # Normally we would like each dataset split to contain at least a few samples, - # i.e. we want num_tokens >= sequence_length * min_split * min_samples_per_split. - # For example with a (999, 1, 0) split , 8K sequence length, we need at least 8M tokens - # for a single validation sample, possibly more if the split is imperfect. - min_tokens: int | None = Field(default=None) - - def run(self): - self.to_logs() - assert self.directory.is_dir() - output_file = self.directory / self.output_name - assert not output_file.exists(), str(output_file) - datasets = [] - - logger.info(f"Loading datasets from {self.directory}") - for path in self.directory.glob("**/*.idx"): - prefix = path.with_suffix("") - logger.info(str(prefix)) - dataset = GPTMemmapDataset("dataset", prefix) - dataset_dict = { - "prefix": str(prefix.relative_to(self.directory)), - "num_documents": len(dataset), - "num_tokens": dataset.num_tokens, - } - if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: - logger.info( - f"Ignoring dataset {dataset_dict['prefix']} with {dataset_dict['num_tokens']:,} tokens" - f" (requiring at least {self.min_tokens:,} tokens)" - ) - else: - datasets.append(dataset_dict) - total_documents = sum(dataset["num_documents"] for dataset in datasets) - total_tokens = sum(dataset["num_tokens"] for dataset in datasets) - logger.info(f"Found {total_documents:,} documents, {total_tokens:,} tokens in {len(datasets)} dataset files") - for dataset in datasets: - dataset["weight"] = dataset["num_tokens"] / total_tokens - logger.info( - f'{dataset["prefix"]}: documents = {dataset["num_documents"]:,}, tokens = {dataset["num_tokens"]:,}, weight = {dataset["weight"]:.6f}' - ) - logger.info(f"Saving merged dataset to {output_file}") - json.dump({"datasets": datasets}, output_file.open("w")) - - -if __name__ == "__main__": - ConcatenateDatasetConfig.parse_and_run() From acfd30ea476c0c11a5ff8233aaa71ea2e5814956 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Oct 2025 17:57:02 -0400 Subject: [PATCH 02/12] fixes --- fast_llm/data/dataset/config.py | 2 +- fast_llm/data/dataset/indexed.py | 19 ++++- fast_llm/data/dataset/memmap.py | 16 ++-- fast_llm/data/preparator/gpt_memmap/config.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 80 +++++++++++-------- fast_llm/data/sample/abstract.py | 58 ++++++++++++-- fast_llm/data/sample/language_model.py | 31 +++++-- fast_llm/data/sample/range.py | 23 
+++--- fast_llm/data/sample/token.py | 20 +++-- tests/data/common.py | 17 ++-- tests/data/test_blending.py | 8 +- tests/data/test_memmap.py | 4 +- tests/data/test_prepare_gpt_memmap.py | 15 ++-- 13 files changed, 207 insertions(+), 88 deletions(-) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index f1bc3472a..f60decd81 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -10,11 +10,11 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.sample.abstract import Sample -from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset + from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index c6eac9e28..5d6636f7f 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -34,11 +34,20 @@ def get_document( ) -> SampleType: pass - @abc.abstractmethod def __len__(self) -> int: """ - Number of samples in the dataset. + Number of documents in the dataset. + Note: this default implementation is slow and should be overridden when possible. + """ + return len(self.get_document_sizes()) + + @property + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. + Note: this default implementation is slow and should be overridden when possible. """ + return self.get_document_sizes().sum().item() def sample(self, sampling: SamplingData) -> "GPTSampledIndexedDataset": from fast_llm.data.dataset.sampled import SampledIndexedDataset @@ -108,6 +117,12 @@ def __init__( def __len__(self) -> int: return self._dataset_splits[-1].item() + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. + """ + return sum(len(dataset) for dataset in self._datasets) + def get_document_sizes(self) -> torch.Tensor: # TODO: This can be really big. return torch.cat([dataset.get_document_sizes() for dataset in self._datasets]) diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index e2aeda077..ffb2bc6d1 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -57,6 +57,8 @@ def __del__(self): def get_document( self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None ) -> SampleType: + if end is None: + end = self._reader.get_document_size(index) return self._reader.get_document(index, begin, end) @property @@ -64,12 +66,11 @@ def name(self) -> str: return self._name def __len__(self) -> int: - return self._reader + return len(self._reader) - # TODO: ====== needed? 
====== - # @property - # def num_tokens(self) -> int: - # return self._reader.num_tokens + @property + def num_tokens(self) -> int: + return self._reader.num_tokens def get_document_sizes(self) -> torch.Tensor: return self._reader.get_document_sizes() @@ -78,7 +79,9 @@ def get_document_size(self, index: int) -> int: return self._reader.get_document_size(index) @classmethod - def write_dataset(cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter]): + def write_dataset( + cls, path: pathlib.Path, documents: typing.Iterable[Sample], writer_class: type[MemmapWriter] + ) -> MemmapIndexDatasetReaderConfig: # TODO: Match `writer_class` with `SampleType`? path.parent.mkdir(parents=True, exist_ok=True) with path.open("wb") as stream: @@ -98,3 +101,4 @@ def write_dataset(cls, path: pathlib.Path, documents: typing.Iterable[Sample], w # Write a pointer to the reader config. stream.seek(start) stream.write(config_offset.to_bytes(4, signed=False)) + return reader_config diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index c193cf942..7dd520ec3 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -65,7 +65,7 @@ def _validate(self): super()._validate() if self.has_loss_masking_span != self.rejected_spans_column is not None: raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") - if self.has_preference_spans == self.has_loss_masking_span: + if self.has_preference_spans and self.has_loss_masking_span: # TODO: ====== Still needed? ====== raise ValueError(f"Can not enable both loss masking and preference spans.") diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 18ab2d787..06a4bd517 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -25,6 +25,7 @@ from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample @@ -43,19 +44,21 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D def __init__(self, config: ConfigType): super().__init__(config) - self._source_shema: LanguageModelSourceConfig = self._config.dataset.source_shema + self._source_schema: LanguageModelSourceConfig = self._config.dataset.source_schema - def _save_shard(self, args: tuple[int, datasets.Dataset]) -> MemmapDatasetConfig: + def _save_shard( + self, args: tuple[int, datasets.Dataset] + ) -> tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]: shard_index, shard_dataset = args file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" - MemmapDataset.write_dataset( + reader_config = MemmapDataset.write_dataset( self._config.output_path / file_name, tqdm.tqdm((sample["sample"] for sample in shard_dataset), desc=f"Saving shard {shard_index}", unit="docs"), LanguageModelWriter, ) - return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}) + return MemmapDatasetConfig.from_dict({"type": "memmap", "path": 
file_name}), reader_config def _load_dataset(self) -> datasets.Dataset: dataset = datasets.load_dataset( @@ -173,7 +176,7 @@ def run(self) -> None: index=self._config.distributed.rank, ) - for column_name in self._source_shema.columns: + for column_name in self._source_schema.columns: if column_name not in dataset.column_names: raise ValueError(f"Dataset does not have field '{column_name}'.") @@ -196,42 +199,42 @@ def run(self) -> None: # Use multiprocessing to save each shard in parallel on all ranks with multiprocessing.Pool(processes=self._config.saving_workers) as pool: - dataset_configs = pool.map(self._save_shard, shards) + dataset_and_reader_configs = pool.map(self._save_shard, shards) - self.generate_config_yaml_for_sharded_dst(dataset_configs) + self.generate_config_yaml_for_sharded_dst(dataset_and_reader_configs) def _prepare_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[LanguageModelSample]]: # Gather values by sample using zip* - sample_column_values = zip(*(batch[column_name] for column_name in self._source_shema.columns)) + sample_column_values = zip(*(batch[column_name] for column_name in self._source_schema.columns)) # Convert to dicts using column names. sample_dicts = ( - {column_name: column_value for column_name, column_value in zip(self._source_shema.columns, sample_data)} + {column_name: column_value for column_name, column_value in zip(self._source_schema.columns, sample_data)} for sample_data in sample_column_values ) # Prepare each sample, wrap in dict for the `Dataset` interface return {"samples": [self._prepare_sample(sample_dict) for sample_dict in sample_dicts]} def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - text = sample[self._source_shema.text_column] + text = sample[self._source_schema.text_column] all_spans = [] - if self._source_shema.has_loss_masking_span: + if self._source_schema.has_loss_masking_span: # TODO: ====== What is the input format? ====== # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( (begin, last + 1) - for begin, last in np.array(sample[self._source_shema.loss_masking_spans_column], dtype=np.int32) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans_column], dtype=np.int32) .reshape(-1, 2) .tolist() ) all_spans.extend(loss_masking_spans) - if self._source_shema.has_preference_spans: + if self._source_schema.has_preference_spans: # TODO: ===== Was `self._config.dataset.field` (bug?) 
====== full_chosen_text = ( - text + sample[self._source_shema.chosen_spans_column] + self._tokenizer.tokenizer.eos_token + text + sample[self._source_schema.chosen_spans_column] + self._tokenizer.tokenizer.eos_token ) full_rejected_text = ( - self._tokenizer.tokenizer.bos_token + text + sample[self._source_shema.rejected_spans_column] + self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_spans_column] ) # compute chosen span chosen_spans = [[len(text), len(full_chosen_text)]] @@ -255,33 +258,37 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: return LanguageModelSample( TokenSample(tokens, [sample_size]), - RangeSample(loss_masking_spans, sample_size) if self._source_shema.has_loss_masking_span else None, - RangeSample(chosen_spans, sample_size) if self._source_shema.has_preference_spans else None, - RangeSample(rejected_span, sample_size) if self._source_shema.has_preference_spans else None, + RangeSample(loss_masking_spans, sample_size) if self._source_schema.has_loss_masking_span else None, + RangeSample(chosen_spans, sample_size) if self._source_schema.has_preference_spans else None, + RangeSample(rejected_span, sample_size) if self._source_schema.has_preference_spans else None, ) - def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[MemmapDatasetConfig]) -> None: + def generate_config_yaml_for_sharded_dst( + self, dataset_and_reader_configs: list[tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]] + ) -> None: # Gather dataset_dicts from all ranks to rank 0 if self._config.distributed.world_size > 1: if self._config.distributed.rank == 0: - all_dataset_configs = [None] * self._config.distributed.world_size - torch.distributed.gather_object(dataset_configs, all_dataset_configs, dst=0) - dataset_configs = [item for sublist in all_dataset_configs for item in sublist] + all_dataset_and_reader_configs = [None] * self._config.distributed.world_size + torch.distributed.gather_object(dataset_and_reader_configs, all_dataset_and_reader_configs, dst=0) + dataset_and_reader_configs = [item for sublist in all_dataset_and_reader_configs for item in sublist] else: - torch.distributed.gather_object(dataset_configs, [], dst=0) + torch.distributed.gather_object(dataset_and_reader_configs, [], dst=0) if self._config.distributed.rank == 0: # Create the config file(s) on rank 0 + dataset_configs, reader_configs = zip(*dataset_and_reader_configs) if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, reader_configs, self._config.splits, self._config.output_path ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" ) else: self._save_dataset_config( - self._blend_dataset_configs(dataset_configs), self._config.output_path / f"fast_llm_config.yaml" + self._blend_dataset_configs(dataset_configs, reader_configs), + self._config.output_path / f"fast_llm_config.yaml", ) # Save metadata on rank 0 @@ -304,7 +311,9 @@ def _save_dataset_config( @classmethod def _blend_dataset_configs( - cls, dataset_configs: list[MemmapDatasetConfig[_sample_type]] + cls, + dataset_configs: list[MemmapDatasetConfig[_sample_type]], + reader_configs: list[MemmapIndexDatasetReaderConfig], ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] @@ -312,7 +321,7 @@ def _blend_dataset_configs( { "type": "blended", "datasets": 
dataset_configs, - "weights": [dataset_config.num_tokens for dataset_config in dataset_configs], + "weights": [reader_config.num_tokens for reader_config in reader_configs], } ) @@ -320,12 +329,13 @@ def _blend_dataset_configs( def _split_and_blend_dataset_configs( cls, dataset_configs: list[MemmapDatasetConfig[_sample_type]], + reader_configs: list[MemmapIndexDatasetReaderConfig], splits: dict[str, int | float], output_path: pathlib.Path, ) -> dict[str, SampledDatasetConfig[_sample_type]]: # TODO: ====== Missing `num_tokens`, `num_documents`. ====== split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() - dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] + dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) dataset_cumsums = padded_cumsum(dataset_probabilities).tolist() dataset_splits = {} @@ -333,7 +343,9 @@ def _split_and_blend_dataset_configs( for split_index, split_name in enumerate(splits): datasets_in_split = [] dataset_tokens_in_split = [] - for dataset_index, dataset_config in enumerate(dataset_configs): + for dataset_index, (dataset_config, reader_config) in enumerate( + zip(dataset_configs, reader_configs, strict=True) + ): split_begin_in_dataset = max( (split_cumsum[split_index] - dataset_cumsums[dataset_index]) / dataset_probabilities[dataset_index], @@ -353,17 +365,17 @@ def _split_and_blend_dataset_configs( # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() sizes_cumsum = dataset.get_document_sizes().numpy().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + Assert.eq(sizes_cumsum[-1], reader_config.num_tokens) + begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * reader_config.num_tokens) + end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * reader_config.num_tokens) if end_index > begin_index: datasets_in_split.append( DatasetSliceConfig[cls._sample_type].from_dict( { "type": "slice", "dataset": dataset_configs[dataset_index], - "begin": begin_index / dataset_config.num_documents, - "end": end_index / dataset_config.num_documents, + "begin": begin_index / len(reader_config), + "end": end_index / len(reader_config), } ) ) diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index f122100f9..9afc6124c 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -71,7 +71,7 @@ def get_reader(self, buffer: memoryview) -> "MemmapReader|None": @property def expected_buffer_size(self) -> int: """ - The expected buffer size in bytes. Used for self-validation. + The expected buffer size in bytes, including header and footer. Used for self-validation. """ raise NotImplementedError() @@ -98,15 +98,33 @@ class MemmapReaderConfig(MemmapReaderBaseConfig): Configuration for a standard memmap reader. """ + # Data location in the file. begin: int = Field() end: int = Field() + # Constant strings for alignment safety. 
+ header: typing.ClassVar[bytes] + footer: typing.ClassVar[bytes] @property def reader_class(self) -> "type[MemmapReader]": raise NotImplementedError() def get_reader(self, buffer: memoryview) -> "MemmapReader": - return self.reader_class(self, buffer[self.begin : self.end]) + return self.reader_class(self, buffer) + + @property + def expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, including header and footer. Used for self-validation. + """ + return self._expected_buffer_size + len(self.header) + len(self.footer) + + @property + def _expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes, excluding header and footer. Used for self-validation. + """ + raise NotImplementedError() @property def writer_class(self) -> "type[MemmapWriter]": @@ -117,7 +135,6 @@ def get_writer(self, stream: io.BufferedWriter) -> "MemmapWriter": def _validate(self): super()._validate() - print("AAAAA", self.__class__.__name__, self.begin, self.end, self.expected_buffer_size) Assert.eq(self.end - self.begin, self.expected_buffer_size) @@ -128,6 +145,15 @@ class MemmapIndexDatasetReaderConfig(MemmapReaderConfig): consisting of a list of documents of known lengths. """ + @abc.abstractmethod + def __len__(self) -> int: + pass + + @property + @abc.abstractmethod + def num_tokens(self) -> int: + pass + @property def reader_class(self) -> "type[MemmapIndexedDatasetReader]": raise NotImplementedError() @@ -136,13 +162,17 @@ def get_reader( self, buffer: memoryview, ) -> "MemmapIndexedDatasetReader": - return self.reader_class(self, buffer[self.begin : self.end]) + return self.reader_class(self, buffer) -class MemmapReader[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): +class MemmapReader[ConfigType: MemmapReaderConfig](Configurable[ConfigType]): def __init__(self, config: ConfigType, buffer: memoryview): super().__init__(config) - self._buffer = buffer[self._config.begin : self._config.end] + buffer_begin = self._config.begin + len(self._config.header) + buffer_end = self._config.end - len(self._config.footer) + Assert.eq(buffer[self._config.begin : buffer_begin].tobytes(), self._config.header) + Assert.eq(buffer[buffer_end : self._config.end].tobytes(), self._config.footer) + self._buffer = buffer[buffer_begin:buffer_end] @abc.abstractmethod def get_document(self, index: int, begin: int, end: int) -> Sample: @@ -150,6 +180,13 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: class MemmapIndexedDatasetReader[ConfigType: MemmapIndexDatasetReaderConfig](MemmapReader[ConfigType]): + def __len__(self) -> int: + return len(self._config) + + @property + def num_tokens(self) -> int: + return self._config.num_tokens + @abc.abstractmethod def get_document_sizes(self) -> "torch.Tensor": pass @@ -159,7 +196,7 @@ def get_document_size(self, index: int) -> int: pass -class MemmapWriter: +class MemmapWriter(abc.ABC): def __init__(self, stream: io.BufferedWriter | pathlib.Path): self._owns_stream = isinstance(stream, pathlib.Path) if self._owns_stream: @@ -168,16 +205,23 @@ def __init__(self, stream: io.BufferedWriter | pathlib.Path): def __enter__(self): self._begin = self._stream.tell() + self._stream.write(self._get_config_class().header) return self def write(self, document: Sample): assert hasattr(self, "_begin") and not hasattr(self, "_end") def __exit__(self, exc_type, exc_val, exc_tb): + self._stream.write(self._get_config_class().footer) self._end = self._stream.tell() if self._owns_stream: self._stream.close() + @classmethod + 
@abc.abstractmethod + def _get_config_class(cls) -> type[MemmapReaderConfig]: + pass + def get_config(self, offset: int = 0) -> MemmapReaderConfig: assert hasattr(self, "_end") return self._get_config(self._begin + offset, self._end + offset) diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 3d6964b30..d6f737c7b 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -125,12 +125,21 @@ def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) @config_class(dynamic_type={MemmapReaderBaseConfig: "language_model"}) class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): _abstract = False + header: typing.ClassVar[bytes] = b"lm begin" + footer: typing.ClassVar[bytes] = b"lm end" tokens: TokenReaderConfig = Field() # Using dynamic type for optional readers for enabling/disabling loss_masking_spans: MemmapReaderBaseConfig = Field() chosen_spans: MemmapReaderBaseConfig = Field() rejected_spans: MemmapReaderBaseConfig = Field() + def __len__(self) -> int: + return len(self.tokens) + + @property + def num_tokens(self) -> int: + return self.tokens.num_tokens + @property def reader_class(self) -> "type[LanguageModelReader]": return LanguageModelReader @@ -140,7 +149,7 @@ def writer_class(self) -> "type[LanguageModelWriter]": return LanguageModelWriter @property - def expected_buffer_size(self) -> int: + def _expected_buffer_size(self) -> int: return ( self.tokens.expected_buffer_size + self.loss_masking_spans.expected_buffer_size @@ -155,13 +164,19 @@ def __init__(self, config: ConfigType, buffer: memoryview): # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global. self._tokens = self._config.tokens.get_reader(buffer) self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) - self._preference_spans = self._config.preference_spans.get_reader(buffer) + self._chosen_spans = self._config.chosen_spans.get_reader(buffer) + self._rejected_spans = self._config.rejected_spans.get_reader(buffer) + + @property + def num_tokens(self) -> int: + return self._config.tokens.num_tokens def get_document(self, index: int, begin: int, end: int) -> Sample: return LanguageModelSample( self._tokens.get_document(index, begin, end), - self._loss_masking_spans.get_document(index, begin, end), - self._preference_spans.get_document(index, begin, end), + None if self._loss_masking_spans is None else self._loss_masking_spans.get_document(index, begin, end), + None if self._chosen_spans is None else self._chosen_spans.get_document(index, begin, end), + None if self._rejected_spans is None else self._rejected_spans.get_document(index, begin, end), ) def get_document_sizes(self) -> torch.Tensor: @@ -248,8 +263,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._directory.cleanup() super().__exit__(exc_type, exc_val, exc_tb) + @classmethod + def _get_config_class(cls) -> type[LanguageModelReaderConfig]: + return LanguageModelReaderConfig + def _get_config(self, begin: int, end: int | None): - tokens = self._token_writer.get_config(begin) + tokens = self._token_writer.get_config(begin + len(LanguageModelReaderConfig.header)) offset = tokens.end if self._has_loss_masking_spans: loss_masking_spans = self._loss_masking_span_writer.get_config(offset) @@ -266,7 +285,7 @@ def _get_config(self, begin: int, end: int | None): rejected_spans = NullReaderConfig() if end is None: - end = offset + end = offset + len(LanguageModelReaderConfig.footer) return 
LanguageModelReaderConfig( begin=begin, diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 88dd1352d..92d5ce7fc 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -63,6 +63,8 @@ def to_samples(self) -> list[RangeSample]: @config_class(dynamic_type={MemmapReaderBaseConfig: "range"}) class RangeReaderConfig(MemmapReaderConfig): _abstract = False + header: typing.ClassVar[bytes] = b"range begin" + footer: typing.ClassVar[bytes] = b"range end" num_documents: int = Field() num_ranges: int = Field() @@ -75,8 +77,8 @@ def writer_class(self) -> "type[RangeWriter]": return RangeWriter @property - def expected_buffer_size(self) -> int: - return (self.num_ranges + 1) * torch.uint32.itemsize * 2 + (self.num_documents + 1) * torch.uint32.itemsize + def _expected_buffer_size(self) -> int: + return self.num_ranges * torch.int32.itemsize * 2 + (self.num_documents + 1) * torch.int32.itemsize class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): @@ -84,17 +86,17 @@ def __init__(self, config: ConfigType, buffer: memoryview): super().__init__(config, buffer) self._ranges = torch.frombuffer( self._buffer, - dtype=torch.uint32, - count=self._config.num_ranges, + dtype=torch.int32, + count=self._config.num_ranges * 2, ).reshape(-1, 2) self._count_cumsums = torch.frombuffer( self._buffer, - dtype=torch.uint32, + dtype=torch.int32, count=self._config.num_documents + 1, offset=self._ranges.nbytes, ) - def get(self, index: int, begin: int, end: int) -> RangeSample: + def get_document(self, index: int, begin: int, end: int) -> Sample: sample_size = end - begin cropped_ranges = ( (max(begin_ - begin, 0), min(end_ - begin, sample_size)) @@ -110,15 +112,18 @@ def __enter__(self): return self def write(self, document: RangeSample): - # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== super().write(document) - self._stream.write(np.array(document.ranges, dtype=np.uint32).tobytes(order="C")) + self._stream.write(np.array(document.ranges, dtype=np.int32).tobytes(order="C")) self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges)) def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(np.array(self._count_cumsum, dtype=np.uint32).tobytes(order="C")) + self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) super().__exit__(exc_type, exc_val, exc_tb) + @classmethod + def _get_config_class(cls) -> type[RangeReaderConfig]: + return RangeReaderConfig + def _get_config(self, begin: int, end: int): return RangeReaderConfig( begin=begin, diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 98ee9a2a1..0e57209c5 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -91,6 +91,11 @@ class TokenReaderConfig(MemmapReaderConfig): num_documents: int = Field() num_tokens: int = Field() data_type: DataType = Field() + header: typing.ClassVar[bytes] = b"token begin" + footer: typing.ClassVar[bytes] = b"token end" + + def __len__(self) -> int: + return self.num_documents @property def reader_class(self) -> "type[TokenReader]": @@ -101,8 +106,8 @@ def writer_class(self) -> "type[TokenWriter]": return TokenWriter @property - def expected_buffer_size(self) -> int: - return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * torch.uint64.itemsize + def _expected_buffer_size(self) -> int: + return self.num_tokens * self.data_type.torch.itemsize + (self.num_documents + 1) * 
torch.int64.itemsize class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): @@ -114,12 +119,13 @@ def __init__(self, config: ConfigType, buffer: memoryview): count=self._config.num_tokens, ) self._size_cumsums = torch.frombuffer( - self._buffer, dtype=torch.uint64, count=self._config.num_documents + 1, offset=self._tokens.nbytes + self._buffer, dtype=torch.int64, count=self._config.num_documents + 1, offset=self._tokens.nbytes ) def get_document(self, index: int, begin: int, end: int) -> Sample: begin_ = self._size_cumsums[index].item() - return TokenSample(torch.from_numpy(self._tokens[begin_ + begin : begin_ + end]), [end - begin]) + # Torch doesn't support type promotion between signed and unsigned types, so we convert here to avoid issues. + return TokenSample(self._tokens[begin_ + begin : begin_ + end].to(torch.int64), [end - begin]) def get_document_sizes(self) -> torch.Tensor: return self._size_cumsums[1:] - self._size_cumsums[:-1] @@ -146,9 +152,13 @@ def write(self, document: TokenSample): self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(np.array(self._size_cumsum, dtype=np.uint64).tobytes(order="C")) + self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) super().__exit__(exc_type, exc_val, exc_tb) + @classmethod + def _get_config_class(cls) -> type[TokenReaderConfig]: + return TokenReaderConfig + def _get_config(self, begin: int, end: int): return TokenReaderConfig( begin=begin, diff --git a/tests/data/common.py b/tests/data/common.py index e6ab8a265..7053666b8 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -111,7 +111,7 @@ def get_test_data_and_compare_samples( for phase, samples in samples_per_dataset.items() } for phase, expected_samples_ in expected_samples.items(): - Assert.all_equal(tokens[phase].to(torch.int64), expected_samples_) + Assert.all_equal(tokens[phase], expected_samples_) return data @@ -130,7 +130,7 @@ def compare_indexed_dataset( sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.int64)) + Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): print(i) @@ -147,9 +147,7 @@ def compare_indexed_dataset( def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal( - torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]).to(torch.int64), expected_samples - ) + Assert.all_equal(torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]), expected_samples) def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): @@ -210,6 +208,9 @@ class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): def build(self) -> "IndexedDataset": return MockMemmapDataset(self) + def __len__(self) -> int: + return self.num_documents + @property def num_tokens(self) -> int: return self.num_documents * self.num_tokens_per_document @@ -224,7 +225,11 @@ def name(self) -> str: return "mock_memmap" def __len__(self) -> int: - return self._config.num_documents + return len(self._config) + + @property + def num_tokens(self) -> int: + return self._config.num_tokens def 
get_document_sizes(self) -> torch.Tensor: return torch.full([self._config.num_documents], self._config.num_tokens_per_document, dtype=torch.int64) diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 49eceee0b..b2b2f0117 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -15,11 +15,11 @@ from tests.utils.dataset import get_test_dataset from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH -_DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" +_DATASET_PATH_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" def _get_test_dataset_mix_1(): - return get_test_dataset(prefix=_DATASET_PREFIX_MIX_1, seed=2345) + return get_test_dataset(_DATASET_PATH_MIX_1, seed=2345) def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, np.ndarray]: @@ -119,7 +119,7 @@ def test_gpt_blended(): "type": "blended", "datasets": [ {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, + {"type": "memmap", "path": _DATASET_PATH_MIX_1}, ], "weights": [0.75, 0.25], }, @@ -138,7 +138,7 @@ def test_gpt_blended_data(): "type": "blended", "datasets": [ {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, + {"type": "memmap", "path": _DATASET_PATH_MIX_1}, ], "weights": [0.75, 0.25], } diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 419b67903..b11f84d9c 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -27,8 +27,8 @@ def test_gpt_memmap(cache_directory): MEMMAP_DATASET_SPANS = { 9: [], - 10: [(0, 2), (2, 7), (7, 10)], - 13: [(0, 2)], + 10: [(0, 1), (2, 6), (7, 9)], + 13: [(0, 1)], 15: [], } diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 1608bb48c..9647264e7 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -46,8 +46,8 @@ def test_write_memmap_dataset(dtype): Assert.all_equal(dataset.get_document(i).tokens.tokens, document.tokens.tokens.to(torch.int64)) -def _generate_valid_span(max_seq_length): - return np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist() +def _generate_valid_span(max_seq_length) -> tuple[int, int]: + return tuple(np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist()) @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) @@ -56,8 +56,8 @@ def test_write_memmap_preference_dataset(dtype): LanguageModelSample( TokenSample(torch.from_numpy(np.random.randint(1000, size=100).astype(dtype))), None, - RangeSample(_generate_valid_span(100), 100), - RangeSample(_generate_valid_span(100), 100), + RangeSample([_generate_valid_span(100)], 100), + RangeSample([_generate_valid_span(100)], 100), ) for _ in range(50) ] @@ -128,6 +128,7 @@ def test_split_dataset(): dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0], + [dataset_config_0], # Mock reader config. {"training": 3, "validation": 1}, pathlib.Path("."), ) @@ -157,6 +158,7 @@ def test_split_datasets_0(): dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], + [dataset_config_0, dataset_config_1], # Mock reader configs. 
{"training": 1, "validation": 1}, pathlib.Path("."), ) @@ -175,7 +177,10 @@ def test_split_datasets_1(): dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") + [dataset_config_0, dataset_config_1], + [dataset_config_0, dataset_config_1], # Mock reader configs. + {"training": 3, "validation": 1}, + pathlib.Path("."), ) config = {key: value.to_dict() for key, value in config.items()} From 34939e930b2d2e1bd3d636c05d2f91e303bbafa1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Oct 2025 19:05:31 -0400 Subject: [PATCH 03/12] fixes --- fast_llm/data/dataset/config.py | 1 - fast_llm/data/dataset/gpt/legacy_memmap.py | 14 ++- fast_llm/data/dataset/memmap.py | 11 +-- fast_llm/data/dataset/sampled.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 17 +--- .../data/preparator/gpt_memmap/prepare.py | 3 +- fast_llm/data/sample/abstract.py | 6 +- fast_llm/data/sample/language_model.py | 1 - fast_llm/data/sample/token.py | 1 - fast_llm/functional/dpo.py | 2 +- tests/data/test_prepare_gpt_memmap.py | 3 +- tests/models/test_match_megatron.py | 89 +++++++++++++++++-- tests/utils/dataset.py | 56 +++++++----- tests/utils/global_variables.py | 6 +- tests/utils/model_configs.py | 1 - 15 files changed, 149 insertions(+), 63 deletions(-) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index f60decd81..7611b4a31 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -109,7 +109,6 @@ class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: - # TODO: ====== `SamplingData` contains more than needed (ex. 
`num_samples`) raise NotImplementedError() diff --git a/fast_llm/data/dataset/gpt/legacy_memmap.py b/fast_llm/data/dataset/gpt/legacy_memmap.py index d8c63e9f9..2a23e378b 100644 --- a/fast_llm/data/dataset/gpt/legacy_memmap.py +++ b/fast_llm/data/dataset/gpt/legacy_memmap.py @@ -6,12 +6,24 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert, div +MEMMAP_DTYPES = { + 1: DataType.uint8, + 2: DataType.int8, + 3: DataType.int16, + 4: DataType.int32, + 5: DataType.int64, + 6: DataType.float32, + 7: DataType.float64, + 8: DataType.uint16, +} +MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" + class LegacyMemmapDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): """ diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index ffb2bc6d1..e51dfb40d 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -40,14 +40,15 @@ def _init(self, name: str, path: pathlib.Path | str) -> None: ) self._memmap = np.memmap(self._path, mode="r") - # TODO: ===== Check num_documents, num_tokens ====== self._reader = reader_config.get_reader(memoryview(self._memmap)) - def __getstate__(self) -> tuple[str, pathlib.Path]: - return (self._name, self._path) + def __getstate__(self) -> tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]: + # We pass the reader config to force its import in data loader workers. + return self._name, self._path, self._reader.config - def __setstate__(self, state: tuple[str, pathlib.Path]): - self._init(*state) + def __setstate__(self, state: tuple[str, pathlib.Path, MemmapIndexDatasetReaderConfig]): + name, path, _ = state + self._init(name, path) def __del__(self): if hasattr(self, "_memmap"): diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 46a518cd0..d51a68746 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -414,7 +414,6 @@ def __getitem__(self, index: int) -> SampleType: document_sampling_index += 1 token_count += document_size - # TODO: ====== Better way to get the class method? ====== return documents[0].from_documents(documents) @property diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 7dd520ec3..a54465080 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -12,18 +12,6 @@ if typing.TYPE_CHECKING: from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -MEMMAP_DTYPES = { - 1: DataType.uint8, - 2: DataType.int8, - 3: DataType.int16, - 4: DataType.int32, - 5: DataType.int64, - 6: DataType.float32, - 7: DataType.float64, - 8: DataType.uint16, -} -MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()} -MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" @config_class() @@ -66,7 +54,6 @@ def _validate(self): if self.has_loss_masking_span != self.rejected_spans_column is not None: raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") if self.has_preference_spans and self.has_loss_masking_span: - # TODO: ====== Still needed? 
====== raise ValueError(f"Can not enable both loss masking and preference spans.") @@ -204,10 +191,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): ) def _validate(self) -> None: - assert self.tokenizer.path is not None - if self.dataset.data_type is not None: - Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() + assert self.tokenizer.path is not None @classmethod def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 06a4bd517..d3d15fa64 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -218,7 +218,7 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: text = sample[self._source_schema.text_column] all_spans = [] if self._source_schema.has_loss_masking_span: - # TODO: ====== What is the input format? ====== + # TODO: ====== What is the exact input format? ====== # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( (begin, last + 1) @@ -333,7 +333,6 @@ def _split_and_blend_dataset_configs( splits: dict[str, int | float], output_path: pathlib.Path, ) -> dict[str, SampledDatasetConfig[_sample_type]]: - # TODO: ====== Missing `num_tokens`, `num_documents`. ====== split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [reader_config.num_tokens for reader_config in reader_configs] dataset_probabilities = normalize_probabilities(dataset_sizes) diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 9afc6124c..0b2e324c3 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -145,14 +145,12 @@ class MemmapIndexDatasetReaderConfig(MemmapReaderConfig): consisting of a list of documents of known lengths. """ - @abc.abstractmethod def __len__(self) -> int: - pass + raise NotImplementedError() @property - @abc.abstractmethod def num_tokens(self) -> int: - pass + raise NotImplementedError() @property def reader_class(self) -> "type[MemmapIndexedDatasetReader]": diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index d6f737c7b..77cc6e8a2 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -205,7 +205,6 @@ def __enter__(self): return self def write(self, document: LanguageModelSample): - # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== super().write(document) # Write tokens. 
self._token_writer.write(document.tokens) diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 0e57209c5..ae190658f 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -142,7 +142,6 @@ def __enter__(self): return self def write(self, document: TokenSample): - # ====== TODO: Make sure input uses end = 1 past last index (currently use last index) ====== super().write(document) if self._data_type is None: self._data_type = document.tokens.dtype diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 7ab0b9ff6..c5ae48eba 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -37,7 +37,7 @@ def compute_dpo_loss( reference_log_probabilities, chosen_spans ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - # TODO: ====== Shouldn't the sigmoid be computed independently for each document? + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? ======= losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) if grad_output is None: diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9647264e7..09a91d6a8 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -8,8 +8,9 @@ from fast_llm.data.dataset.config import IndexedDatasetConfig from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter from fast_llm.data.sample.range import RangeSample diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 42a7c1f0d..4b057dabd 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -1,21 +1,25 @@ import os +import pathlib +import struct import typing import numpy as np import pytest +import yaml from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import MemmapDatasetConfig, SampledDatasetConfig from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.dataset.gpt.legacy_memmap import LegacyMemmapDataset +from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER, LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.dataset import get_model_test_dataset +from tests.utils.dataset import get_test_dataset_samples from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PATH +from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -26,6 +30,20 @@ except ImportError: _extension_available = 
False +MEGATRON_DATASET_PREFIX = DATASET_CACHE / "megatron_dataset/dataset" + + +def get_megatron_test_dataset(prefix: pathlib.Path = MEGATRON_DATASET_PREFIX): + if not ( + prefix.with_suffix(".idx").is_file() + and prefix.with_suffix(".bin").is_file() + and prefix.parent.joinpath("fast_llm_config.yaml").is_file() + ): + MegatronMemmapDataset.write_dataset(prefix, get_test_dataset_samples(vocab_size=MODEL_TEST_VOCAB_SIZE)) + yaml.safe_dump( + {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") + ) + @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.megatron) @@ -35,11 +53,12 @@ def test_megatron(run_distributed_script, model_testing_config, run_test_script_ # Prevent Megatron from complaining. env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env["NVTE_FLASH_ATTN"] = "0" - get_model_test_dataset() + get_megatron_test_dataset() run_distributed_script( [ "Megatron-LM/pretrain_gpt.py", *model_testing_config.megatron_args, + f"--data-path={MEGATRON_DATASET_PREFIX}", f"--structured-logs-dir={path}", f"--data-cache-path={path}", ], @@ -69,7 +88,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare="megatron", config_args=[ "model.distributed.compute_dtype=fp32", - f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PATH}}}', + f'data.datasets.training={{"type":"megatron","path":{MEGATRON_DATASET_PREFIX}}}', "data.sampling.seed=1234", "model.base_model.use_megatron_initialization=True", ], @@ -97,6 +116,66 @@ class MegatronMemmapDataset(LegacyMemmapDataset): def sample(self, sampling: GPTSamplingData) -> "MegatronSampledIndexedDataset": return MegatronSampledIndexedDataset(self, sampling) + @classmethod + def write_dataset( + cls, + prefix: pathlib.Path | str, + documents: typing.Iterable[LanguageModelSample], + ) -> None: + # Initialize metadata + dtype = None + num_documents = 0 + lengths = [] + pointers = [] + offset = 0 + + prefix = pathlib.Path(prefix) + prefix.parent.mkdir(parents=True, exist_ok=True) + + # Write the binary data file (.bin) lazily + with prefix.with_suffix(".bin").open("wb") as bin_stream: + for document in documents: + token_ids = document.tokens.tokens + # Infer dtype from the first document + if dtype is None: + dtype = token_ids.dtype + assert dtype is not None, "Document dtype could not be inferred from the data." + + # Ensure all documents have the same dtype + assert token_ids.dtype == dtype, f"Expected dtype {dtype}, got {token_ids.dtype}." 
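+                # The .bin file holds the raw token ids back to back; per-document lengths and byte
+                # offsets are accumulated below and written to the companion .idx index file after
+                # the loop.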
+ + # Write document to binary file + bin_stream.write(token_ids.numpy().tobytes(order="C")) + + # Update metadata + doc_length = len(token_ids) + lengths.append(doc_length) + pointers.append(offset) + offset += doc_length * dtype.itemsize + num_documents += 1 + + # Finalize metadata arrays + lengths = np.array(lengths, dtype=np.int32) + pointers = np.array(pointers, dtype=np.int64) + + # Write the index file (.idx) + with prefix.with_suffix(".idx").open("wb") as idx_stream: + idx_stream.write(MEMMAP_INDEX_HEADER) + # Version + idx_stream.write(struct.pack(" list[LanguageModelSample]: + import transformers + + download_santacoder_tokenizer() + + texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() + tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) + + samples = [ + LanguageModelSample( + TokenSample(torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size)), + ) + for document in texts + ] + if max_spans > 0: + spans = get_random_spans( + len(samples), max_spans, np.array([[max(len(sample), 1)] for sample in samples]), seed + ) + for sample, sample_spans in zip(samples, spans, strict=True): + sample.loss_masking_spans = RangeSample(sample_spans, len(sample)) + return samples + + def get_test_dataset( path: pathlib.Path = DATASET_PATH, seed: int = 1234, @@ -45,29 +74,16 @@ def get_test_dataset( vocab_size: int = TEST_VOCAB_SIZE, max_spans: int = 0, ): - download_santacoder_tokenizer() config_path = path.parent.joinpath("fast_llm_config.yaml") if not (path.is_file() and config_path.is_file()): - import transformers - - texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() - tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) - - samples = [ - LanguageModelSample( - TokenSample( - torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) - ), - ) - for document in texts - ] - if max_spans > 0: - spans = get_random_spans( - len(samples), max_spans, np.array([[max(len(sample), 1)] for sample in samples]), seed - ) - for sample, sample_spans in zip(samples, spans, strict=True): - sample.loss_masking_spans = RangeSample(sample_spans, len(sample)) + samples = get_test_dataset_samples( + seed=seed, + num_tokens=num_tokens, + characters=characters, + vocab_size=vocab_size, + max_spans=max_spans, + ) MemmapDataset.write_dataset(path, samples, LanguageModelWriter) yaml.safe_dump({"type": "memmap", "path": path.name}, config_path.open("w")) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index c62903a6c..ea770be0a 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -37,13 +37,13 @@ def set_testing_global_variables(): TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PATH = DATASET_CACHE / "common_dataset.fast_llm_dataset" -DATASET_WITH_SPANS_PATH = DATASET_CACHE / "dataset_with_spans.fast_llm_dataset" +DATASET_PATH = DATASET_CACHE / "common_dataset/dataset.fast_llm_dataset" +DATASET_WITH_SPANS_PATH = DATASET_CACHE / "dataset_with_spans/dataset.fast_llm_dataset" DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" TEST_VOCAB_SIZE = 8192 # Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" TEST_DATASET_TOKENS = 1000000 -MODEL_DATASET_PATH = DATASET_CACHE / 
"model_dataset.fast_llm_dataset" +MODEL_DATASET_PATH = DATASET_CACHE / "model_dataset/dataset.fast_llm_dataset" MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index adcf84b18..ee9c2b730 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -279,7 +279,6 @@ def _update_and_add_testing_config( "--tokenizer-type=NullTokenizer", # Megatron messes with the vocab size, so we have to subtract 1. f"--vocab-size={MODEL_TEST_VOCAB_SIZE - 1}", - f"--data-path={MODEL_DATASET_PATH}", "--split=1,0,0", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron) From c5fa07214aab5fd230d9e62aaf6bd0a38e5e1588 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Oct 2025 19:46:03 -0400 Subject: [PATCH 04/12] int64 --- fast_llm/data/dataset/memmap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index e51dfb40d..4b1930dd3 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -33,7 +33,7 @@ def _init(self, name: str, path: pathlib.Path | str) -> None: # Very file type. assert stream.read(len(FILE_HEADER)) == FILE_HEADER # Go to reader configs. - stream.seek(int.from_bytes(stream.read(4), signed=False)) + stream.seek(int.from_bytes(stream.read(8), signed=False)) # Read the reader config. reader_config = MemmapIndexDatasetReaderConfig.from_dict( json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) @@ -91,7 +91,7 @@ def write_dataset( # Leave space for a pointer to the reader config. # We write the config at the end since we don't know it yet. start = stream.tell() - stream.seek(start + 4) + stream.seek(start + 8) # Write the data. reader_config = writer_class.write_dataset(stream, documents) # Write the reader config. @@ -101,5 +101,5 @@ def write_dataset( stream.write(reader_config_bytes) # Write a pointer to the reader config. 
stream.seek(start) - stream.write(config_offset.to_bytes(4, signed=False)) + stream.write(config_offset.to_bytes(8, signed=False)) return reader_config From cd286766f08b0a470da007628ddae27ddcc9f583 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Nov 2025 21:29:42 -0500 Subject: [PATCH 05/12] Test and fix preparator --- fast_llm/data/config.py | 41 ---- fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/dataset/gpt/fim.py | 8 +- fast_llm/data/dataset/gpt/random.py | 1 + fast_llm/data/dataset/indexed.py | 3 +- fast_llm/data/preparator/config.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 60 +++-- .../data/preparator/gpt_memmap/prepare.py | 189 +++++++-------- fast_llm/data/preprocessing/__init__.py | 0 fast_llm/data/preprocessing/tokenizer.py | 196 ++++++++++++++++ fast_llm/data/sample/abstract.py | 5 +- fast_llm/data/sample/language_model.py | 52 ++-- fast_llm/data/sample/range.py | 8 +- fast_llm/data/sample/token.py | 7 +- fast_llm/engine/config_utils/data_type.py | 6 +- fast_llm/engine/config_utils/runnable.py | 2 +- fast_llm/engine/evaluation/config.py | 2 +- fast_llm/utils.py | 23 +- tests/data/common.py | 87 +------ tests/data/test_blending.py | 90 +++---- tests/data/test_concatenate.py | 50 ++-- tests/data/test_dataset_from_file.py | 12 - tests/data/test_fim.py | 51 ++-- tests/data/test_loss_masking_spans.py | 78 ++++++ tests/data/test_memmap.py | 47 ---- tests/data/test_preference_spans.py | 105 +++++++++ tests/data/test_preparator.py | 197 ++++++++++++++++ tests/data/test_prepare_gpt_memmap.py | 211 ----------------- tests/data/test_random.py | 16 +- tests/data/test_sampling.py | 45 ++-- tests/data/test_slice.py | 56 ++--- tests/functional/test_functional.py | 11 +- tests/models/test_match_megatron.py | 20 +- tests/utils/dataset.py | 222 ++++++++++++------ tests/utils/global_variables.py | 14 +- tests/utils/model_configs.py | 8 +- 36 files changed, 1075 insertions(+), 854 deletions(-) create mode 100644 fast_llm/data/preprocessing/__init__.py create mode 100644 fast_llm/data/preprocessing/tokenizer.py delete mode 100644 tests/data/test_dataset_from_file.py create mode 100644 tests/data/test_loss_masking_spans.py delete mode 100644 tests/data/test_memmap.py create mode 100644 tests/data/test_preference_spans.py create mode 100644 tests/data/test_preparator.py delete mode 100644 tests/data/test_prepare_gpt_memmap.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 633367c80..78bc20636 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,12 +1,4 @@ import enum -import pathlib -import typing - -from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - from fast_llm.data.tokenizer import Tokenizer class MultiprocessingContext(str, enum.Enum): @@ -15,36 +7,3 @@ class MultiprocessingContext(str, enum.Enum): fork = "fork" # Safe but much slower. spawn = "spawn" - - -TokenizerFromFile = "TokenizerFromFile" - - -@config_class() -class TokenizerConfig(Config): - """ - Configuration for the tokenizer. - The tokenizer is needed for FIM and dataset preparation. 
- """ - - format: str = Field( - default="TokenizerFromFile", - desc="Unused.", - hint=FieldHint.deprecated, - valid=check_field(Assert.eq, TokenizerFromFile), - ) - path: pathlib.Path = Field( - default=None, - desc="Path to the tokenizer file.", - hint=FieldHint.core, - ) - bos_token: str | None = Field( - default=None, - desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", - hint=FieldHint.core, - ) - - def get_tokenizer(self) -> "Tokenizer": - from fast_llm.data.tokenizer import Tokenizer - - return Tokenizer(self) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9ff6654c2..7583345c3 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -6,9 +6,9 @@ import yaml from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert @@ -23,7 +23,8 @@ class GPTSamplingParameters(SamplingParameters): Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - vocab_size: int + # TODO: Only used for random dataset. Remove? Or use as safety check? + vocab_size: int | None = None use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 1fde74530..d36384ee5 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -5,6 +5,7 @@ from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import MAX_SEED @@ -168,9 +169,10 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=sequence.dtype) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=sequence.dtype) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=sequence.dtype) + data_type = DataType.from_numpy(sequence.dtype) + prefix = self._tokenizer.tokenize(prefix, end=False, data_type=data_type).numpy() + middle = self._tokenizer.tokenize(middle, begin=False, end=False, data_type=data_type).numpy() + suffix = self._tokenizer.tokenize(suffix, begin=False, data_type=data_type).numpy() # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 463c5a7d6..f1e73c595 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -29,6 +29,7 @@ def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed self._parameters = sampling.parameters + assert self._parameters.vocab_size is not None # TODO: Support? 
assert not self._parameters.use_loss_masking_spans assert not self._parameters.use_preference_loss_spans diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 5d6636f7f..b2e6f7e3d 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -117,11 +117,12 @@ def __init__( def __len__(self) -> int: return self._dataset_splits[-1].item() + @property def num_tokens(self) -> int: """ Number of tokens in the dataset. """ - return sum(len(dataset) for dataset in self._datasets) + return sum(dataset.num_tokens for dataset in self._datasets) def get_document_sizes(self) -> torch.Tensor: # TODO: This can be really big. diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index 160fccafc..a774fc3de 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -7,7 +7,6 @@ @config_class(registry=True, dynamic_type={RunnableConfig: "prepare"}) class DatasetPreparatorConfig(RunnableConfig): - preparator_name: typing.ClassVar[str] @classmethod def get_dataset_preparator_class(cls) -> type["DatasetPreparator"]: diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index a54465080..9bf292033 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -4,8 +4,8 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import TokenizerConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.utils import Assert @@ -16,50 +16,53 @@ @config_class() class LanguageModelSourceConfig(Config): - text_column: str = Field( + """ + A schema holding the name of each relevant column in the dataset. + Setting optional entries will enable the associated feature. 
+ """ + + text: str = Field( default="text", desc="Field of the dataset to use.", hint=FieldHint.optional, ) - loss_masking_spans_column: None | str = Field( + loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) - chosen_spans_column: None | str = Field( + chosen_span: None | str = Field( default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional ) - rejected_spans_column: None | str = Field( + rejected_span: None | str = Field( default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional ) @functools.cached_property def columns(self) -> list[str]: - columns = [self.text_column] + columns = [self.text] if self.has_loss_masking_span: - columns.append(self.loss_masking_spans_column) + columns.append(self.loss_masking_spans) if self.has_preference_spans: - columns.extend([self.chosen_spans_column, self.rejected_spans_column]) + columns.extend([self.chosen_span, self.rejected_span]) return columns @functools.cached_property def has_loss_masking_span(self) -> bool: - return self.loss_masking_spans_column is not None + return self.loss_masking_spans is not None @functools.cached_property def has_preference_spans(self) -> bool: - Assert.eq(self.chosen_spans_column is None, self.rejected_spans_column is None) - return self.chosen_spans_column is not None + Assert.eq(self.chosen_span is None, self.rejected_span is None) + return self.chosen_span is not None def _validate(self): super()._validate() - if self.has_loss_masking_span != self.rejected_spans_column is not None: - raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") if self.has_preference_spans and self.has_loss_masking_span: raise ValueError(f"Can not enable both loss masking and preference spans.") @config_class() class GPTHuggingfaceDatasetConfig(Config): - path: str = Field( + path: str | pathlib.Path = Field( default=None, desc="Name or path of the dataset.", hint=FieldHint.core, @@ -104,6 +107,11 @@ class GPTHuggingfaceDatasetConfig(Config): desc="Disable disk space check. 
Useful for environments where disk space is not accurately reported.", hint=FieldHint.optional, ) + load_from_disk: bool = Field( + default=False, + desc="Use the `load_from_disk` method for datasets saved with `save_to_disk`.", + hint=FieldHint.feature, + ) @config_class() @@ -141,7 +149,6 @@ def _validate(self) -> None: @config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"}) class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): - preparator_name: typing.ClassVar[str] = "gpt_memmap" output_path: pathlib.Path = Field( default=None, desc="Output directory for the processed dataset.", @@ -151,27 +158,14 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for distributed processing.", hint=FieldHint.feature, ) - tokens_per_shard: int = Field( - default=10**9, - desc="Approximate number of tokens per shard.", + documents_per_shard: int = Field( + default=10**6, + desc="Target number of documents per shard.", hint=FieldHint.feature, - valid=check_field(Assert.geq, 10**5), - ) - loading_workers: int = Field( - default=1, - desc="Number of workers in load_dataset() call.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), - ) - tokenize_workers: int = Field( - default=1, - desc="Number of workers for tokenization.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), ) - saving_workers: int = Field( + num_workers: int = Field( default=1, - desc="Number of processes for saving the data.", + desc="Number of parallel workers.", hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d3d15fa64..18d4d46e2 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import collections +import enum import json import logging import math @@ -25,17 +27,24 @@ from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, LanguageModelSourceConfig +from fast_llm.data.preprocessing.tokenizer import Tokenizer from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample -from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type +from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum logger = logging.getLogger(__name__) +class SpanType(enum.StrEnum): + loss_masking = "loss_masking" + chosen = "chosen" + rejected = "rejected" + + class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): _tokenizer: Tokenizer _data_type: DataType @@ -46,30 +55,19 @@ def __init__(self, config: ConfigType): super().__init__(config) self._source_schema: LanguageModelSourceConfig = self._config.dataset.source_schema - def _save_shard( - self, args: tuple[int, datasets.Dataset] - ) -> tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]: - shard_index, shard_dataset = args - file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" - - reader_config = 
MemmapDataset.write_dataset( - self._config.output_path / file_name, - tqdm.tqdm((sample["sample"] for sample in shard_dataset), desc=f"Saving shard {shard_index}", unit="docs"), - LanguageModelWriter, - ) - - return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config - def _load_dataset(self) -> datasets.Dataset: - dataset = datasets.load_dataset( - path=self._config.dataset.path, - name=self._config.dataset.config_name, - data_dir=self._config.dataset.data_directory, - data_files=self._config.dataset.data_files, - split=self._config.dataset.split, - num_proc=self._config.loading_workers, - trust_remote_code=self._config.dataset.trust_remote_code, - ) + if self._config.dataset.load_from_disk: + dataset = datasets.load_from_disk(self._config.dataset.path)[self._config.dataset.split] + else: + dataset = datasets.load_dataset( + path=self._config.dataset.path, + name=self._config.dataset.config_name, + data_dir=self._config.dataset.data_directory, + data_files=self._config.dataset.data_files, + split=self._config.dataset.split, + num_proc=self._config.num_workers, + trust_remote_code=self._config.dataset.trust_remote_code, + ) assert isinstance(dataset, datasets.Dataset) return dataset @@ -137,6 +135,7 @@ def run(self) -> None: # Initialize distributed processing if self._config.distributed.world_size > 1: + log_main_rank(f"> Initializing distributed process groups ...") torch.distributed.init_process_group( backend=self._config.distributed.backend, rank=self._config.distributed.rank, @@ -146,31 +145,18 @@ def run(self) -> None: # Prepare output directory self._config.output_path.mkdir(parents=True, exist_ok=True) - downloaded = pathlib.Path(self._config.dataset.path).is_dir() - if self._config.distributed.world_size > 1: - torch.distributed.barrier() - - if downloaded: - # Dataset is already downloaded, load from disk + log_main_rank(f"> Loading dataset `{self._config.dataset.path}` ...") + if self._config.distributed.world_size == 1: + dataset = self._load_dataset() + elif self._config.distributed.rank == 0: + # Load first on rank 0 to prevent parallel downloads. 
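+            # The remaining ranks wait at their barrier until the download completes, then read from the cache.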
dataset = self._load_dataset() + torch.distributed.barrier() else: - # Dataset is not downloaded, download on rank 0 - if self._config.distributed.rank == 0: - dataset = self._load_dataset() - - # Synchronize processes to wait for the download to finish on rank 0 - if self._config.distributed.world_size > 1: - torch.distributed.barrier() - + torch.distributed.barrier() # Load the downloaded dataset on remaining ranks - if self._config.distributed.rank != 0: - dataset = self._load_dataset() - - # Synchronize processes to wait for the dataset to load on remaining ranks - if self._config.distributed.world_size > 1: - torch.distributed.barrier() + dataset = self._load_dataset() - assert isinstance(dataset, datasets.Dataset) dataset = dataset.shard( num_shards=self._config.distributed.world_size, index=self._config.distributed.rank, @@ -180,49 +166,45 @@ def run(self) -> None: if column_name not in dataset.column_names: raise ValueError(f"Dataset does not have field '{column_name}'.") - # Tokenize the dataset in parallel - prepared_dataset = dataset.map( - self._prepare_batch, - batched=True, - num_proc=self._config.tokenize_workers, - desc="Tokenizing batches", - ) - # Split dataset into shards based on number of tokens - num_shards = math.ceil( - sum(len(sample) for sample in prepared_dataset["samples"]) / self._config.tokens_per_shard - ) - shards = [ - (i, prepared_dataset.shard(num_shards=num_shards, index=i)) - for i in tqdm.tqdm(range(num_shards), desc="Creating shards") - ] + num_shards = math.ceil(len(dataset) / self._config.documents_per_shard) + shards = [(i, dataset.shard(num_shards=num_shards, index=i)) for i in range(num_shards)] + + log_main_rank(f"> Preparing samples on {self._config.num_workers} workers ...") # Use multiprocessing to save each shard in parallel on all ranks - with multiprocessing.Pool(processes=self._config.saving_workers) as pool: - dataset_and_reader_configs = pool.map(self._save_shard, shards) + with multiprocessing.Pool(processes=self._config.num_workers) as pool: + dataset_and_reader_configs = pool.map(self._prepare_shard, shards) + log_main_rank(f"> Generating dataset config ...") self.generate_config_yaml_for_sharded_dst(dataset_and_reader_configs) - def _prepare_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[LanguageModelSample]]: - # Gather values by sample using zip* - sample_column_values = zip(*(batch[column_name] for column_name in self._source_schema.columns)) - # Convert to dicts using column names. 
- sample_dicts = ( - {column_name: column_value for column_name, column_value in zip(self._source_schema.columns, sample_data)} - for sample_data in sample_column_values + def _prepare_shard( + self, args: tuple[int, datasets.Dataset] + ) -> tuple[MemmapDatasetConfig, MemmapIndexDatasetReaderConfig]: + shard_index, shard_dataset = args + file_name = f"shard_{self._config.distributed.rank}_{shard_index}.fast_llm_dataset" + + reader_config = MemmapDataset.write_dataset( + self._config.output_path / file_name, + ( + self._prepare_sample(sample) + for sample in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_index}", unit="docs") + ), + LanguageModelWriter, ) - # Prepare each sample, wrap in dict for the `Dataset` interface - return {"samples": [self._prepare_sample(sample_dict) for sample_dict in sample_dicts]} + return MemmapDatasetConfig.from_dict({"type": "memmap", "path": file_name}), reader_config def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: - text = sample[self._source_schema.text_column] + # TODO: ======= Extract so we can use elsewhere? (ex. inference) ====== + text = sample[self._source_schema.text] all_spans = [] if self._source_schema.has_loss_masking_span: # TODO: ====== What is the exact input format? ====== # Spans are typically stored in the (begin, last) format. We convert to (begin, end) range format. loss_masking_spans = _sort_spans( - (begin, last + 1) - for begin, last in np.array(sample[self._source_schema.loss_masking_spans_column], dtype=np.int32) + (SpanType.loss_masking, (begin, last + 1)) + for begin, last in np.array(sample[self._source_schema.loss_masking_spans], dtype=np.int32) .reshape(-1, 2) .tolist() ) @@ -230,37 +212,58 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: if self._source_schema.has_preference_spans: # TODO: ===== Was `self._config.dataset.field` (bug?) ====== - full_chosen_text = ( - text + sample[self._source_schema.chosen_spans_column] + self._tokenizer.tokenizer.eos_token - ) - full_rejected_text = ( - self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_spans_column] - ) + full_chosen_text = text + sample[self._source_schema.chosen_span] + self._tokenizer.tokenizer.eos_token + full_rejected_text = self._tokenizer.tokenizer.bos_token + text + sample[self._source_schema.rejected_span] # compute chosen span - chosen_spans = [[len(text), len(full_chosen_text)]] + chosen_spans = [(SpanType.chosen, (len(text), len(full_chosen_text)))] # compute rejected span rejected_span = [ - [ - len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), - len(full_chosen_text) + len(full_rejected_text), - ] + ( + SpanType.rejected, + ( + len(full_chosen_text) + len(self._tokenizer.tokenizer.bos_token) + len(text), + len(full_chosen_text) + len(full_rejected_text), + ), + ) ] # pack texts text = full_chosen_text + full_rejected_text all_spans.extend(chosen_spans + rejected_span) - tokens = torch.tensor( - self._tokenizer.tokenize_with_spans(text, True, True, spans=_sort_spans(all_spans)), - dtype=self._data_type.torch, + # Sort the spans by location (begin), keeping track of their type. + # Note: overlapping spans are not supported (explicit assertion in the tokenizer). + span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) + # Tokenize the text, and determine the span locations in the tokenized text. 
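+        # `tokenize_with_spans` splits the text at the span boundaries, tokenizes each piece
+        # separately, and returns the spans as token offsets aligned with token boundaries.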
+ tokens, token_spans = self._tokenizer.tokenize_with_spans( + text, True, True, text_spans=spans, data_type=self._data_type ) + + # Gather token spans by type. + token_spans_by_type = collections.defaultdict(list) + for span_type, token_span in zip(span_types, token_spans, strict=True): + token_spans_by_type[span_type].append(token_span) + sample_size = len(tokens) return LanguageModelSample( TokenSample(tokens, [sample_size]), - RangeSample(loss_masking_spans, sample_size) if self._source_schema.has_loss_masking_span else None, - RangeSample(chosen_spans, sample_size) if self._source_schema.has_preference_spans else None, - RangeSample(rejected_span, sample_size) if self._source_schema.has_preference_spans else None, + ( + RangeSample(token_spans_by_type[SpanType.loss_masking], sample_size) + if self._source_schema.has_loss_masking_span + else None + ), + ( + RangeSample(token_spans_by_type[SpanType.chosen], sample_size) + if self._source_schema.has_preference_spans + else None + ), + ( + # `tokenize_with_spans` excludes the final eod token from the rejected span, but we want to include it. + RangeSample([(begin, end + 1) for begin, end in token_spans_by_type[SpanType.rejected]], sample_size) + if self._source_schema.has_preference_spans + else None + ), ) def generate_config_yaml_for_sharded_dst( @@ -402,8 +405,8 @@ def _split_and_blend_dataset_configs( return dataset_splits -def _sort_spans(spans: typing.Iterable[tuple[int, int]]) -> list[tuple[int, int]]: - return sorted(spans, key=lambda span: span[0]) +def _sort_spans(spans: typing.Iterable[tuple[SpanType, tuple[int, int]]]) -> list[tuple[SpanType, tuple[int, int]]]: + return sorted(spans, key=lambda span: span[1][0]) def _get_nearest_split(cumsum: np.ndarray, value: float) -> int: diff --git a/fast_llm/data/preprocessing/__init__.py b/fast_llm/data/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py new file mode 100644 index 000000000..70291bcaa --- /dev/null +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -0,0 +1,196 @@ +import pathlib +import typing + +from fast_llm.config import Config, Configurable, Field, FieldHint, config_class +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + import numpy as np + import torch + + +@config_class() +class TokenizerConfig(Config): + """ + Configuration for the tokenizer. + The tokenizer is needed for FIM and dataset preparation. + """ + + path: pathlib.Path = Field( + default=None, + desc="Path to the tokenizer file.", + hint=FieldHint.core, + ) + bos_token: str | None = Field( + default=None, + desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", + hint=FieldHint.core, + ) + max_vocab_size: int | None = Field( + default=None, + desc="Constrain output tokens to a specific range. Used for testing.", + hint=FieldHint.testing, + ) + + def get_tokenizer(self) -> "Tokenizer": + from fast_llm.data.preprocessing.tokenizer import Tokenizer + + return Tokenizer(self) + + +class Tokenizer[ConfigType: TokenizerConfig](Configurable[ConfigType]): + """ + A wrapper around Huggingface (transformers) tokenizer. 
+ """ + + def __init__(self, config: ConfigType): + super().__init__(config) + from transformers import AutoTokenizer + + log_main_rank(f"> loading tokenizer from {config.path} ...") + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=self._config.path, + errors="replace", + max_len=None, + trust_remote_code=True, + use_fast=True, + ) + if self._config.bos_token is not None: + self.tokenizer.bos_token = self._config.bos_token + if self.tokenizer.eos_token_id is None: + raise ValueError("Tokenizer does not have an EOS token.") + if self.tokenizer.bos_token_id is None: + raise ValueError("Tokenizer does not have an BOS token.") + self.eod_id = self.tokenizer.eos_token_id + self.bod_id = self.tokenizer.bos_token_id + + @property + def vocab_size(self) -> int: + return len(self.tokenizer) + + @property + def vocab(self) -> dict[str, int]: + return self.tokenizer.vocab + + @property + def inv_vocab(self) -> dict[int, str]: + return self._inv_vocab + + def tokenize( + self, text: str, begin: bool = True, end: bool = True, data_type: DataType = DataType.int64 + ) -> "torch.Tensor": + import torch + + tokens = torch.tensor( + ([self.bod_id] if begin else []) + + self.tokenizer.encode(text, add_special_tokens=False) + + ([self.eod_id] if end else []), + dtype=data_type.torch, + ) + if self._config.max_vocab_size is not None: + tokens %= self._config.max_vocab_size + return tokens + + def tokenize_with_spans( + self, + text: str, + begin: bool = True, + end: bool = True, + *, + text_spans: list[tuple[int, int]], + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[tuple[int, int]]]: + """ + Perform span-aware tokenization and return the tokenized input_ids along with token spans. + """ + if not text_spans: + return self.tokenize(text, begin, end, data_type=data_type), [] + input_ids, token_splits = self.tokenize_with_splits( + text, begin, end, text_splits=[split for splits in text_spans for split in splits], data_type=data_type + ) + return input_ids, [(begin, end) for begin, end in zip(token_splits[::2], token_splits[1::2], strict=True)] + + def tokenize_with_splits( + self, + text: str, + begin: bool = True, + end: bool = True, + *, + text_splits: list[int], + data_type: DataType = DataType.int64, + ) -> tuple["torch.Tensor", list[int]]: + if not text_splits: + return self.tokenize(text, begin, end, data_type=data_type), [] + import torch + + Assert.eq(sorted(text_splits), text_splits) + input_ids = [] + text_splits = [0, *text_splits, len(text)] + token_splits = [] + total_tokens = 0 + + for i, (split_begin, split_end) in enumerate(zip(text_splits[:-1], text_splits[1:])): + input_ids.append( + split_tokens := self.tokenize( + text[split_begin:split_end], + begin and i == 0, + end and i == len(text_splits) - 2, + data_type=data_type, + ) + ) + total_tokens += len(split_tokens) + token_splits.append(total_tokens) + + return torch.cat(input_ids), token_splits[:-1] + + def detokenize( + self, tokens: "int | list[int] | np.ndarray | torch.Tensor", begin: bool = False, end: bool = False + ) -> str: + tokens = self._remove_delimiters(tokens, begin, end) + return self.tokenizer.decode(tokens) + + def detokenize_with_spans( + self, tokens: "torch.Tensor", begin: bool = False, end: bool = False, *, token_spans: list[tuple[int, int]] + ) -> tuple[str, list[tuple[int, int]]]: + if not token_spans: + return self.detokenize(tokens, begin, end), [] + text, text_splits = self.detokenize_with_splits( + tokens, begin, end, token_splits=[split for splits in 
token_spans for split in splits] + ) + return text, [(begin, end) for begin, end in zip(text_splits[::2], text_splits[1::2], strict=True)] + + def detokenize_with_splits( + self, tokens: "torch.Tensor", begin: bool = False, end: bool = False, *, token_splits: list[int] + ) -> tuple[str, list[int]]: + if not token_splits: + return self.detokenize(tokens, begin, end), [] + Assert.eq(sorted(token_splits), token_splits) + tokens = self._remove_delimiters(tokens, begin, end) + texts = [] + token_splits = [0, *(token_split - begin for token_split in token_splits), len(tokens)] + text_splits = [] + total_characters = 0 + + for i, (split_begin, split_end) in enumerate(zip(token_splits[:-1], token_splits[1:])): + texts.append(split_text := self.detokenize(tokens[split_begin:split_end])) + total_characters += len(split_text) + text_splits.append(total_characters) + + return "".join(texts), text_splits[:-1] + + def _remove_delimiters( + self, token_ids: "int | list[int] | np.ndarray | torch.Tensor", begin: bool = False, end: bool = False + ): + if begin: + Assert.eq(token_ids[0], self.bod_id) + token_ids = token_ids[1:] + if end: + Assert.eq(token_ids[-1], self.eod_id) + token_ids = token_ids[:-1] + return token_ids + + @property + def eod(self): + return self.eod_id diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 0b2e324c3..aaa321efd 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -210,8 +210,9 @@ def write(self, document: Sample): assert hasattr(self, "_begin") and not hasattr(self, "_end") def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(self._get_config_class().footer) - self._end = self._stream.tell() + if exc_type is None: + self._stream.write(self._get_config_class().footer) + self._end = self._stream.tell() if self._owns_stream: self._stream.close() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 77cc6e8a2..6f485bf84 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -89,9 +89,9 @@ def to_samples(self) -> list[LanguageModelSample]: LanguageModelSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( self.tokens.to_samples(), - self.loss_masking_spans.to_samples(), - self.chosen_spans.to_samples(), - self.rejected_spans.to_samples(), + None if self.loss_masking_spans is None else self.loss_masking_spans.to_samples(), + None if self.chosen_spans is None else self.chosen_spans.to_samples(), + None if self.rejected_spans is None else self.rejected_spans.to_samples(), strict=True, ) ] @@ -237,27 +237,31 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) - # A dummy config so we can verify the begin and end offsets. 
- config = self._get_config(self._begin, None) - _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) - - if self._has_loss_masking_spans: - _copy_chunked( - self._path.joinpath("loss_masking_spans"), - self._stream, - config.loss_masking_spans.begin, - config.loss_masking_spans.end, - ) - if self._has_preference_spans: - _copy_chunked( - self._path.joinpath("chosen_spans"), self._stream, config.chosen_spans.begin, config.chosen_spans.end - ) - _copy_chunked( - self._path.joinpath("rejected_spans"), - self._stream, - config.rejected_spans.begin, - config.rejected_spans.end, - ) + if exc_type is None: + # A dummy config so we can verify the begin and end offsets. + config = self._get_config(self._begin, None) + _copy_chunked(self._path.joinpath("tokens"), self._stream, config.tokens.begin, config.tokens.end) + + if self._has_loss_masking_spans: + _copy_chunked( + self._path.joinpath("loss_masking_spans"), + self._stream, + config.loss_masking_spans.begin, + config.loss_masking_spans.end, + ) + if self._has_preference_spans: + _copy_chunked( + self._path.joinpath("chosen_spans"), + self._stream, + config.chosen_spans.begin, + config.chosen_spans.end, + ) + _copy_chunked( + self._path.joinpath("rejected_spans"), + self._stream, + config.rejected_spans.begin, + config.rejected_spans.end, + ) self._directory.cleanup() super().__exit__(exc_type, exc_val, exc_tb) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 92d5ce7fc..c3a035376 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -12,7 +12,7 @@ MemmapWriter, Sample, ) -from fast_llm.utils import get_unique +from fast_llm.utils import Assert, get_unique class RangeSample(Sample): @@ -88,7 +88,7 @@ def __init__(self, config: ConfigType, buffer: memoryview): self._buffer, dtype=torch.int32, count=self._config.num_ranges * 2, - ).reshape(-1, 2) + ).view(-1, 2) self._count_cumsums = torch.frombuffer( self._buffer, dtype=torch.int32, @@ -117,7 +117,9 @@ def write(self, document: RangeSample): self._count_cumsum.append(self._count_cumsum[-1] + len(document.ranges)) def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) + if exc_type is None: + Assert.lt(self._count_cumsum[-1], np.iinfo(np.int32).max) + self._stream.write(np.array(self._count_cumsum, dtype=np.int32).tobytes(order="C")) super().__exit__(exc_type, exc_val, exc_tb) @classmethod diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index ae190658f..706b5053a 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -88,11 +88,11 @@ def to_device_(self, device: "torch.device | str"): @config_class(dynamic_type={MemmapReaderBaseConfig: "token"}) class TokenReaderConfig(MemmapReaderConfig): _abstract = False + header: typing.ClassVar[bytes] = b"token begin" + footer: typing.ClassVar[bytes] = b"token end" num_documents: int = Field() num_tokens: int = Field() data_type: DataType = Field() - header: typing.ClassVar[bytes] = b"token begin" - footer: typing.ClassVar[bytes] = b"token end" def __len__(self) -> int: return self.num_documents @@ -151,7 +151,8 @@ def write(self, document: TokenSample): self._size_cumsum.append(self._size_cumsum[-1] + len(document.tokens)) def __exit__(self, exc_type, exc_val, exc_tb): - self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) + if exc_type is None: + self._stream.write(np.array(self._size_cumsum, 
dtype=np.int64).tobytes(order="C")) super().__exit__(exc_type, exc_val, exc_tb) @classmethod diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 1a0fed91b..27709a8bb 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -50,9 +50,13 @@ def from_torch(cls, dtype: "torch.dtype") -> "DataType": return _TORCH_DTYPE_MAP_INV[dtype] @classmethod - def from_numpy(cls, dtype: "np.dtype") -> "DataType": + def from_numpy(cls, dtype: "np.dtype | type[np.number]") -> "DataType": + import numpy as np + if not _NUMPY_DTYPE_MAP_INV: _set_numpy_dtype_map() + if isinstance(dtype, np.dtype): + dtype = dtype.type return _NUMPY_DTYPE_MAP_INV[dtype] @classmethod diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 051163084..163a9459c 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -106,7 +106,7 @@ def _get_runnable(self) -> typing.Callable[[], None]: return self.run def run(self) -> None: - raise NotImplementedError() + self._get_runnable()() def _show[ T diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index f8dfd4825..df7ab0f51 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -2,7 +2,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.data.config import TokenizerConfig +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.utils import Assert diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 1f9feceb4..83675ac74 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -161,21 +161,22 @@ def rms_close_relative(x, y, threshold, min_threshold=0): assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" @staticmethod - def all_equal(x, y): + def all_equal(x, *args): import torch # Make it work for lists and numpy arrays. 
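+        # Each remaining argument is converted the same way and compared against the first.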
x = torch.as_tensor(x) - y = torch.as_tensor(y) - - Assert.eq(x.shape, y.shape) - neq = x != y - if neq.any().item(): # noqa - index = None if x.numel() == 1 else torch.where(neq) # noqa - raise AssertionError( - f"Tensors have {index[0].numel()} different entries out of " - f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}" - ) + for arg in args: + arg = torch.as_tensor(arg) + + Assert.eq(x.shape, arg.shape) + neq = x != arg + if neq.any().item(): # noqa + index = None if x.numel() == 1 else torch.where(neq) # noqa + raise AssertionError( + f"Tensors have {index[0].numel()} different entries out of " + f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + ) @staticmethod def all_different(x, y): diff --git a/tests/data/common.py b/tests/data/common.py index 7053666b8..ac8d8023c 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -4,26 +4,18 @@ import numpy as np import torch -from fast_llm.config import Field, FieldHint, NoAutoValidate, config_class +from fast_llm.config import NoAutoValidate from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import ( - IndexedDatasetConfig, - SampledDatasetConfig, - SamplingConfig, - SamplingParameters, - ShufflingType, -) +from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig, ShufflingType from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset -from fast_llm.data.sample.abstract import Sample from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert, div -from tests.utils.global_variables import TEST_VOCAB_SIZE def get_sampling_data( @@ -33,7 +25,7 @@ def get_sampling_data( cache_directory: pathlib.Path | None = None, phase=PhaseType.training, sequence_length: int = 512, - vocab_size=TEST_VOCAB_SIZE, + vocab_size: int | None = None, gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, @@ -73,7 +65,7 @@ def get_test_data_and_compare_samples( shuffle: ShufflingType = ShufflingType.epoch, cache_directory: pathlib.Path | None = None, sequence_length: int = 512, - vocab_size=TEST_VOCAB_SIZE, + vocab_size: int | None = None, expected_samples: dict[str, list[list[int]]] | list[list[int]], ) -> GPTData: distributed_config = DistributedConfig(seed=87522) @@ -115,34 +107,21 @@ def get_test_data_and_compare_samples( return data -def compare_indexed_dataset( +def compare_indexed_dataset_tokens( dataset: IndexedDataset, length: int, num_tokens: int, expected_samples: dict[int, list[int]], - loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) sizes = dataset.get_document_sizes() - # Assert.eq(sizes.sum(), num_tokens) + Assert.eq(sizes.sum(), num_tokens, dataset.num_tokens) Assert.all_equal( [len(dataset.get_document(i).tokens.tokens) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample)) - if loss_masking_spans: - for i, loss_masking_span in loss_masking_spans.items(): - print(i) - 
Assert.eq( - dataset.get_document( - i, - parameters=GPTSamplingParameters( - num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True - ), - ).loss_masking_spans.ranges, - loss_masking_spans[i], - ) def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: @@ -183,61 +162,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s for index in range(sampled._parameters.num_samples) ] token_ids = torch.stack([sampled[i].tokens.tokens for i in range(len(sampled))]).to(torch.int64) - Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: Assert.all_equal(token_ids, expected_samples) return token_ids - - -@config_class(dynamic_type={SampledDatasetConfig: "mock_memmap"}) -class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - num_documents: int | None = Field( - default=None, - desc="Expected number of documents in the dataset.", - hint=FieldHint.core, - ) - num_tokens_per_document: int | None = Field( - default=None, - desc="Expected number of tokens in the dataset.", - hint=FieldHint.optional, - ) - path: pathlib.Path = Field(default=".") - - def build(self) -> "IndexedDataset": - return MockMemmapDataset(self) - - def __len__(self) -> int: - return self.num_documents - - @property - def num_tokens(self) -> int: - return self.num_documents * self.num_tokens_per_document - - -class MockMemmapDataset[SampleType: Sample](IndexedDataset[SampleType]): - def __init__(self, config: MockGPTMemmapDatasetConfig): - self._config = config - - @property - def name(self) -> str: - return "mock_memmap" - - def __len__(self) -> int: - return len(self._config) - - @property - def num_tokens(self) -> int: - return self._config.num_tokens - - def get_document_sizes(self) -> torch.Tensor: - return torch.full([self._config.num_documents], self._config.num_tokens_per_document, dtype=torch.int64) - - def get_document_size(self, index: int) -> int: - return self._config.num_tokens_per_document - - def get_document( - self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None - ) -> SampleType: - raise NotImplementedError() diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index b2b2f0117..88ecf2c99 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -12,17 +12,11 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH - -_DATASET_PATH_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" - - -def _get_test_dataset_mix_1(): - return get_test_dataset(_DATASET_PATH_MIX_1, seed=2345) +from tests.utils.dataset import get_alt_test_dataset, get_common_test_dataset def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, np.ndarray]: + # Alternate implementation for blending. 
probs = np.array(probs) dataset_index = np.zeros(num_samples) sample_index = np.zeros(num_samples) @@ -37,25 +31,25 @@ def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, GPT_BLENDED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [4628, 7392, 920, 79, 1322, 387], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [387, 4224, 87, 2713, 423, 324], - [3036, 253, 207, 2968, 4536, 1178], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [49152, 83, 80, 20452, 45, 93], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [93, 90, 39, 6, 75, 9], + [58, 22885, 93, 37, 92, 76], ] GPT_BLENDED_MIXED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], + [49152, 46, 10, 819, 19, 45], [916, 6683, 7685, 1277, 5106, 378], - [1790, 80, 6506, 1735, 542, 88], + [45, 69, 17, 86, 38826, 15], [3359, 6803, 780, 4561, 669, 7878], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], [6920, 2218, 2921, 3963, 7606, 6904], - [207, 4700, 549, 79, 417, 3036], + [1793, 1, 1746, 38, 27, 58], ] @@ -112,38 +106,21 @@ def test_blending(probs): def test_gpt_blended(): # Make sure dataset blending works and check for unintended changes in behavior. - get_test_dataset() - _get_test_dataset_mix_1() + _, config, _ = get_common_test_dataset() + _, alt_config, _ = get_alt_test_dataset() sampled = get_dataset_config( - { + dataset_config := { "type": "blended", - "datasets": [ - {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PATH_MIX_1}, - ], + "datasets": [config, alt_config], "weights": [0.75, 0.25], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) - -def test_gpt_blended_data(): - get_test_dataset() - _get_test_dataset_mix_1() + # Test in data. get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "blended", - "datasets": [ - {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PATH_MIX_1}, - ], - "weights": [0.75, 0.25], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_BLENDED_SAMPLES, @@ -152,34 +129,25 @@ def test_gpt_blended_data(): def test_gpt_blended_mixed(): # Make sure dataset blending works and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() sampled = get_dataset_config( - { + dataset_config := { "type": "blended", "datasets": [ - {"type": "memmap", "path": DATASET_PATH}, + config, {"type": "random"}, ], "weights": [0.6, 0.4], }, BlendedDatasetConfig[LanguageModelSample], - ).build_and_sample(get_sampling_data(8, sequence_length=5)) + ).build_and_sample(get_sampling_data(8, sequence_length=5, vocab_size=8192)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) - -def test_gpt_blended_mixed_data(): - get_test_dataset() + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "blended", - "datasets": [{"type": "memmap", "path": DATASET_PATH}, {"type": "random"}], - "weights": [0.6, 0.4], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, + vocab_size=8192, expected_samples=GPT_BLENDED_MIXED_SAMPLES, ) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 7b009bbf6..d7e750c8b 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,56 +1,48 @@ from fast_llm.data.dataset.config import ConcatenatedDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( - compare_indexed_dataset, + compare_indexed_dataset_tokens, compare_sampled_dataset, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, ) -from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH +from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_SAMPLES, COMMON_DATASET_TOKENS +from tests.utils.dataset import get_common_test_dataset GPT_CONCATENATED_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [3036, 253, 207, 2968, 4536, 1178], - [1178, 3291, 317, 277, 2679, 89], - [89, 542, 395, 583, 684, 554], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [58, 22885, 93, 37, 92, 76], + [76, 29, 19, 17365, 93, 46], + [46, 83, 17211, 1, 785, 1023], ] def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() + memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() dataset = get_dataset_config( - {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)]}, + dataset_config := {"type": "concatenated", "datasets": [memmap_config.to_dict() for _ in range(3)]}, ConcatenatedDatasetConfig[LanguageModelSample], ).build() - compare_indexed_dataset( + compare_indexed_dataset_tokens( dataset, - 3 * MEMMAP_DATASET_LENGTH, - 3 * MEMMAP_DATASET_TOKENS, - {j * MEMMAP_DATASET_LENGTH + i: sample for j in range(3) for i, sample in MEMMAP_DATASET_SAMPLES.items()}, + 3 * COMMON_DATASET_LENGTH, + 3 * COMMON_DATASET_TOKENS, + {j * COMMON_DATASET_LENGTH + i: sample for j in range(3) for i, sample in COMMON_DATASET_SAMPLES.items()}, ) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_CONCATENATED_SAMPLES) - -def test_gpt_concatenate_data(): - get_test_dataset() + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "concatenated", - "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)], - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_CONCATENATED_SAMPLES, diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py deleted file mode 100644 index af91df1e2..000000000 --- a/tests/data/test_dataset_from_file.py +++ /dev/null @@ -1,12 +0,0 @@ -from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig -from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH - - -def test_dataset_from_file(): - get_test_dataset() - dataset_config = {"type": "file", "path": str(DATASET_PATH.parent.joinpath("fast_llm_config.yaml"))} - dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() - compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index b9dc7fe32..0600c5258 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -5,34 +5,30 @@ get_sampling_data, get_test_data_and_compare_samples, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH, TOKENIZER_PATH +from tests.utils.dataset import get_common_test_dataset +from tests.utils.global_variables import TOKENIZER_PATH GPT_FIM_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [86, 89, 7876, 80, 49152, 87], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [86, 89, 1178, 49152, 87, 49152], - [86, 49152, 1178, 64, 89, 900], - [86, 49152, 89, 542, 395, 89], + [46, 10, 819, 19, 45, 88], + [45, 69, 17, 86, 38826, 15], + [86, 89, 32348, 64, 49152, 87], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [86, 89, 37, 92, 76, 49152], + [86, 49152, 76, 29, 19, 89], + [86, 49152, 46, 83, 17211, 1], ] def test_gpt_fim(): # Make sure the FIM wrapper works in a simple case and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() # The test tokenizer doesn't have fim tokens, so we work around it. 
- sampling_config = get_sampling_data( - 8, - sequence_length=5, - vocab_size=49157, - ) + sampling_config = get_sampling_data(8, sequence_length=5) sampled = get_dataset_config( - { + dataset_config := { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PATH}, + "dataset": config, "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", @@ -44,26 +40,9 @@ def test_gpt_fim(): ).build_and_sample(sampling_config) compare_sampled_dataset(sampled, GPT_FIM_SAMPLES) - -def test_gpt_fim_data(): - get_test_dataset() get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PATH}, - "tokenizer": {"path": TOKENIZER_PATH}, - "rate": 0.5, - "prefix_token": "w", - "middle_token": "x", - "pad_token": "y", - "suffix_token": "z", - } - }, - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_FIM_SAMPLES, - vocab_size=49157, ) diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py new file mode 100644 index 000000000..521eaf2a9 --- /dev/null +++ b/tests/data/test_loss_masking_spans.py @@ -0,0 +1,78 @@ +import datasets + +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.utils import Assert +from tests.data.common import get_dataset_config +from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_TEXT +from tests.utils.dataset import get_test_dataset_with_loss_masking_spans +from tests.utils.global_variables import TOKENIZER_NAME + +DATASET_WITH_SPAN_TOKENS = 46199 +DATASET_WITH_SPAN_SAMPLES = { + 27: [49152, 63, 82, 11, 84, 71, 49152], + 30: [49152, 31, 85, 78, 27, 34, 46, 62, 43, 49152], + 31: [49152, 60, 55, 80, 30, 85, 22, 18, 49152], + 77: [49152, 73, 80, 85, 52, 22, 46, 5, 88, 78, 49152], + 87: [49152, 52, 89, 75, 11, 71, 49152], +} +HF_LOSS_MASKING_SPANS = { + 27: [[0, 1], [3, 3]], + 30: [[0, 0], [2, 2], [5, 5]], + 31: [[0, 0], [2, 2], [4, 4]], + 77: [[0, 0], [3, 5], [7, 7]], + 87: [[1, 1], [3, 3]], +} +TOKEN_LOSS_MASKING_SPANS = { + 27: [(1, 3), (4, 5)], + 30: [(1, 2), (3, 4), (6, 7)], + 31: [(1, 2), (3, 4), (5, 6)], + 77: [(1, 2), (4, 7), (8, 9)], + 87: [(2, 3), (4, 5)], +} + + +def test_gpt_data_with_spans(): + _, config, hf_path = get_test_dataset_with_loss_masking_spans() + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + + hf_dataset = datasets.load_from_disk(hf_path)["train"] + tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() + + # Check global stats. + Assert.eq(len(dataset), len(hf_dataset), COMMON_DATASET_LENGTH) + Assert.eq(dataset.num_tokens, DATASET_WITH_SPAN_TOKENS) + + for index in range(0, 200, 8): + expected_text = hf_dataset[index]["text"] + expected_text_spans = [(begin, last + 1) for begin, last in hf_dataset[index]["loss_masking_spans"]] + expected_tokens, expected_spans = tokenizer.tokenize_with_spans( + hf_dataset[index]["text"], + text_spans=[(begin, last + 1) for begin, last in hf_dataset[index]["loss_masking_spans"]], + ) + document = dataset.get_document( + index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) + ) + + # Compare tokens and token spans. 
+ Assert.all_equal(document.tokens.tokens, expected_tokens) + Assert.eq(document.loss_masking_spans.ranges, expected_spans) + + # Compare text. + text, text_spans = tokenizer.detokenize_with_spans( + document.tokens.tokens, True, True, token_spans=document.loss_masking_spans.ranges + ) + Assert.eq(text, expected_text) + Assert.eq(text_spans, expected_text_spans) + + # Check some numerical values. + for index in DATASET_WITH_SPAN_SAMPLES: + Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) + Assert.eq(hf_dataset[index]["loss_masking_spans"], HF_LOSS_MASKING_SPANS[index]) + document = dataset.get_document( + index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) + ) + Assert.all_equal(document.tokens.tokens, DATASET_WITH_SPAN_SAMPLES[index]) + Assert.all_equal(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py deleted file mode 100644 index b11f84d9c..000000000 --- a/tests/data/test_memmap.py +++ /dev/null @@ -1,47 +0,0 @@ -import pathlib - -import pytest - -from fast_llm.data.dataset.config import MemmapDatasetConfig -from tests.data.common import compare_indexed_dataset, get_dataset_config -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH, DATASET_SAMPLING_CACHE, DATASET_WITH_SPANS_PATH - -MEMMAP_DATASET_LENGTH = 6153 -MEMMAP_DATASET_TOKENS = 508327 -MEMMAP_DATASET_SAMPLES = { - 9: [], - 10: [80, 85, 4295, 4182, 489, 727, 84, 698, 1197, 583], - 13: [78, 727, 74, 317, 1358, 89], - 15: [78], -} - - -@pytest.mark.parametrize("cache_directory", (None, pathlib.Path(DATASET_SAMPLING_CACHE) / "test_memmap")) -def test_gpt_memmap(cache_directory): - # Make sure the memmap dataset works and check for unintended changes in behavior. 
- get_test_dataset() - dataset = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build() - compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) - - -MEMMAP_DATASET_SPANS = { - 9: [], - 10: [(0, 1), (2, 6), (7, 9)], - 13: [(0, 1)], - 15: [], -} - - -def test_gpt_data_with_spans(): - get_test_dataset(DATASET_WITH_SPANS_PATH, max_spans=5) - dataset = get_dataset_config( - { - "type": "memmap", - "path": DATASET_WITH_SPANS_PATH, - }, - MemmapDatasetConfig, - ).build() - compare_indexed_dataset( - dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_SPANS - ) diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py new file mode 100644 index 000000000..7b570c5a1 --- /dev/null +++ b/tests/data/test_preference_spans.py @@ -0,0 +1,105 @@ +import datasets +import numpy as np +import torch + +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.utils import Assert +from tests.data.common import get_dataset_config +from tests.data.test_preparator import COMMON_DATASET_LENGTH +from tests.utils.dataset import get_test_dataset_with_preference_spans +from tests.utils.global_variables import TOKENIZER_NAME + +DATASET_WITH_PREFERENCE_SPAN_TOKENS = 62163 +DATASET_WITH_PREFERENCE_SPAN_TEXT = { + 27: ["`", "s,", "uh"], + 30: ["@v", "o{hf_dataset[index]["answer"]}<|endoftext|>", + ) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py deleted file mode 100644 index 09a91d6a8..000000000 --- a/tests/data/test_prepare_gpt_memmap.py +++ /dev/null @@ -1,211 +0,0 @@ -import json -import pathlib -import tempfile - -import numpy as np -import pytest -import torch - -from fast_llm.data.dataset.config import IndexedDatasetConfig -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES -from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig -from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter -from fast_llm.data.sample.range import RangeSample -from fast_llm.data.sample.token import TokenSample -from fast_llm.utils import Assert -from tests.data.common import MockGPTMemmapDatasetConfig # Noqa - - -def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator: - config = GPTMemmapDatasetPreparatorConfig.from_dict( - { - "output_path": output_path, - "dataset": {"path": dataset_path_name}, - "tokenizer": {"path": "no_tokenizer"}, - }, - {}, - ) - return config.get_dataset_preparator_class()(config=config) - - -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_dataset(dtype): - documents = [ - LanguageModelSample( - TokenSample(torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype))) - ) - for _ in range(100) - ] - with tempfile.TemporaryDirectory() as temp_dir: - path = pathlib.Path(temp_dir) / "dataset" - MemmapDataset.write_dataset(path, documents, LanguageModelWriter) - dataset = MemmapDataset("dataset", path) - for 
i, document in enumerate(documents): - Assert.all_equal(dataset.get_document(i).tokens.tokens, document.tokens.tokens.to(torch.int64)) - - -def _generate_valid_span(max_seq_length) -> tuple[int, int]: - return tuple(np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist()) - - -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_preference_dataset(dtype): - documents = [ - LanguageModelSample( - TokenSample(torch.from_numpy(np.random.randint(1000, size=100).astype(dtype))), - None, - RangeSample([_generate_valid_span(100)], 100), - RangeSample([_generate_valid_span(100)], 100), - ) - for _ in range(50) - ] - with tempfile.TemporaryDirectory() as temp_dir: - path = pathlib.Path(temp_dir) / "dataset" - MemmapDataset.write_dataset(path, documents, LanguageModelWriter) - dataset = MemmapDataset("dataset", path) - parameters = GPTSamplingParameters( - num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True - ) - for i, document in enumerate(documents): - dataset_document = dataset.get_document(i, parameters=parameters) - Assert.all_equal(dataset_document.tokens.tokens, document.tokens.tokens.to(torch.int64)) - Assert.eq(dataset_document.chosen_spans.ranges, document.chosen_spans.ranges) - Assert.eq(dataset_document.rejected_spans.ranges, document.rejected_spans.ranges) - - -def test_load_metadata_from_hub(): - with tempfile.TemporaryDirectory(suffix="test") as local_folder: - get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata() - croissant_path = pathlib.Path(local_folder) / "croissant.json" - assert croissant_path.is_file() - metadata = json.load(croissant_path.open("r")) - assert metadata["url"] == "https://huggingface.co/datasets/lhoestq/demo1" - - -def test_absent_metadata_from_hub(): - with tempfile.TemporaryDirectory(suffix="test") as local_folder: - get_preparator(local_folder, "allenai/dolma")._save_croissant_metadata() - assert not (pathlib.Path(local_folder) / "croissant.json").is_file() - - -def test_load_metadata_local(): - with ( - tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder, - tempfile.TemporaryDirectory(suffix="test") as local_folder, - ): - metadata = {"name": "test"} - json.dump(metadata, (pathlib.Path(dataset_folder) / "croissant.json").open("w")) - get_preparator(local_folder, dataset_folder)._save_croissant_metadata() - croissant_path = pathlib.Path(local_folder) / "croissant.json" - assert croissant_path.is_file() - assert json.loads(croissant_path.open("r").read()) == metadata - - -def test_absent_metadata_local(): - with ( - tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder, - tempfile.TemporaryDirectory(suffix="test") as local_folder, - ): - get_preparator(local_folder, dataset_folder)._save_croissant_metadata() - assert not (pathlib.Path(local_folder) / "croissant.json").is_file() - - -DATASET_DICT_0 = { - "type": "mock_memmap", - "num_documents": 500, - "num_tokens_per_document": 300, -} -DATASET_DICT_1 = { - "type": "mock_memmap", - "num_documents": 1500, - "num_tokens_per_document": 100, -} - - -def test_split_dataset(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0], - [dataset_config_0], # Mock reader config. 
- {"training": 3, "validation": 1}, - pathlib.Path("."), - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": { - "type": "slice", - "dataset": dataset_config_0.to_dict(), - "begin": 0, - "end": 0.75, - }, - "validation": { - "type": "slice", - "dataset": dataset_config_0.to_dict(), - "begin": 0.75, - "end": 1, - }, - }, - ) - - -def test_split_datasets_0(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0, dataset_config_1], - [dataset_config_0, dataset_config_1], # Mock reader configs. - {"training": 1, "validation": 1}, - pathlib.Path("."), - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": dataset_config_0.to_dict(), - "validation": dataset_config_1.to_dict(), - }, - ) - - -def test_split_datasets_1(): - dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) - config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( - [dataset_config_0, dataset_config_1], - [dataset_config_0, dataset_config_1], # Mock reader configs. - {"training": 3, "validation": 1}, - pathlib.Path("."), - ) - config = {key: value.to_dict() for key, value in config.items()} - - Assert.eq( - config, - { - "training": { - "type": "blended", - "datasets": [ - dataset_config_0.to_dict(), - { - "type": "slice", - "dataset": dataset_config_1.to_dict(), - "begin": 0, - "end": 0.5, - }, - ], - "weights": [2 / 3, 1 / 3], - }, - "validation": { - "type": "slice", - "dataset": dataset_config_1.to_dict(), - "begin": 0.5, - "end": 1, - }, - }, - ) diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 8e5c61904..7a31358b9 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -16,22 +16,16 @@ def test_gpt_random_dataset(): # Make sure the random dataset works and check for unintended changes in behavior. - sampled = get_dataset_config({"type": "random"}, GPTRandomDatasetConfig).build_and_sample( - get_sampling_data(4, sequence_length=7) + sampled = get_dataset_config(config := {"type": "random"}, GPTRandomDatasetConfig).build_and_sample( + get_sampling_data(4, sequence_length=7, vocab_size=8192) ) compare_sampled_dataset(sampled, RANDOM_DATASET_EXPECTED_SAMPLES) - -def test_gpt_random_data(): + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "random", - } - } - }, + {"datasets": {"training": config}}, 4, sequence_length=7, + vocab_size=8192, expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index c171d15dd..2d102be01 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -2,8 +2,8 @@ import pytest import torch -from fast_llm.data.dataset.config import MemmapDatasetConfig, ShufflingType -from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.config import ShufflingType +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.token import TokenSample @@ -14,8 +14,7 @@ get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH +from tests.utils.dataset import get_common_test_dataset try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa @@ -26,37 +25,28 @@ GPT_MEMMAP_SAMPLES = [ - [4709, 819, 79, 207, 277, 1790], - [1790, 80, 6506, 1735, 542, 88], - [88, 4302, 269, 2794, 119, 80], - [80, 207, 567, 498, 89, 207], - [207, 4700, 549, 79, 417, 3036], - [3036, 253, 207, 2968, 4536, 1178], - [1178, 3291, 317, 277, 2679, 89], - [89, 542, 395, 583, 684, 554], + [49152, 46, 10, 819, 19, 45], + [45, 69, 17, 86, 38826, 15], + [15, 25, 51, 31, 32348, 64], + [64, 17, 93, 78, 40, 1793], + [1793, 1, 1746, 38, 27, 58], + [58, 22885, 93, 37, 92, 76], + [76, 29, 19, 17365, 93, 46], + [46, 83, 17211, 1, 785, 1023], ] def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. - get_test_dataset() - sampled = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build_and_sample( - get_sampling_data(8, sequence_length=5) - ) + _, config, _ = get_common_test_dataset() + sampled = get_dataset_config( + dataset_config := config, GPTDatasetFromFileConfig[LanguageModelSample] + ).build_and_sample(get_sampling_data(8, sequence_length=5)) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) - -def test_gpt_sampled_data(): - get_test_dataset() + # Test in data. 
get_test_data_and_compare_samples( - { - "datasets": { - "training": { - "type": "memmap", - "path": DATASET_PATH, - } - } - }, + {"datasets": {"training": dataset_config}}, 8, sequence_length=5, expected_samples=GPT_MEMMAP_SAMPLES, @@ -169,7 +159,6 @@ def test_gpt_sample_padding(): sampling = get_sampling_data( num_samples=len(expected_samples), sequence_length=sequence_length, - vocab_size=vocab_size, seed=seed, shuffle=ShufflingType.disabled, truncate_documents=False, diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 3a6b999cd..224b18270 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,67 +1,67 @@ from fast_llm.data.dataset.config import DatasetSliceConfig +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( - compare_indexed_dataset, + compare_indexed_dataset_tokens, get_dataset_config, get_sampling_data, get_test_data_and_compare_samples, validate_indexed_dataset_sampling, ) -from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES -from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PATH +from tests.data.test_preparator import COMMON_DATASET_SAMPLES +from tests.utils.dataset import get_common_test_dataset GPT_SLICE_TRAINING_SAMPLES = [ - [80, 268, 79, 260, 207, 3086], - [3086, 80, 413, 4872, 4602, 207], - [207, 7208, 1489, 776, 3514, 269], - [269, 73, 7367, 267, 477, 3126], + [49152, 20, 59, 81, 15, 54], + [54, 76, 7909, 44, 41, 1], + [1, 71, 28, 10, 42, 15963], + [15963, 80, 59, 86, 4, 74], ] GPT_SLICE_VALIDATION_SAMPLES = [ - [1886, 317, 5621, 3173, 330, 284], - [284, 2846, 706, 89, 80, 2047], - [2047, 207, 2449, 1423, 65, 985], - [985, 683, 4917, 87, 477, 481], - [481, 695, 947, 5871, 2344, 87], - [87, 489, 207, 489, 269, 356], - [356, 727, 7800, 4078, 243, 3712], - [3712, 86, 476, 80, 2547, 7390], + [49152, 3, 5621, 27, 7859, 13009], + [13009, 73, 32, 29, 32, 3], + [3, 89, 15, 45, 25, 75], + [75, 52, 13366, 88, 54, 19], + [19, 2, 74, 23, 92, 24747], + [24747, 42, 6, 477, 21, 47], + [47, 92, 31, 30, 463, 64], + [64, 23, 11, 56, 23555, 85], ] def test_gpt_slice(): # Make sure dataset splitting works and check for unintended changes in behavior. - get_test_dataset() + _, config, _ = get_common_test_dataset() + memmap_config = GPTDatasetFromFileConfig.from_dict(config)._load_config() # samples[9:18] dataset = get_dataset_config( - {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003}, + {"type": "slice", "dataset": memmap_config, "begin": 0.025, "end": 0.1}, DatasetSliceConfig[LanguageModelSample], ).build() - compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) + compare_indexed_dataset_tokens(dataset, 75, 3399, {i - 25: sample for i, sample in COMMON_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) validate_indexed_dataset_sampling(sampled, GPT_SLICE_VALIDATION_SAMPLES) - -def test_gpt_slice_data(): + # Test in data with multiple phases. 
get_test_data_and_compare_samples( { "datasets": { "training": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PATH}, + "dataset": memmap_config, "begin": 0, - "end": 0.0015, + "end": 0.025, }, "validation": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PATH}, - "begin": 0.0015, - "end": 0.003, + "dataset": memmap_config, + "begin": 0.025, + "end": 0.1, }, "test": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PATH}, - "begin": 0.003, + "dataset": memmap_config, + "begin": 0.1, "end": 1, }, } diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 489f5e1c1..3a90745eb 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -61,12 +62,12 @@ def reference_dpo_loss( def test_dpo_loss(): - torch.manual_seed(0) - logits = torch.randn((10, 50, 100), requires_grad=True) - reference_model_logits = torch.randn((10, 50, 100)) - targets = torch.randint(0, 100, (10, 50)) + random_state = np.random.RandomState(0) + logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32).requires_grad_() + reference_model_logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32) + targets = torch.from_numpy(random_state.randint(0, 100, (10, 50))) - spans = get_random_spans(10, 10, 50) + spans = get_random_spans(10, 10, 50, random_state) fastllm_loss, fast_llm_grad = compute_dpo_loss( logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 4b057dabd..f3ce65966 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -3,8 +3,10 @@ import struct import typing +import datasets import numpy as np import pytest +import torch import yaml from fast_llm.config import Field, FieldHint, config_class @@ -13,13 +15,15 @@ from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.gpt.legacy_memmap import MEMMAP_DTYPES, MEMMAP_INDEX_HEADER, LegacyMemmapDataset from fast_llm.data.dataset.sampled import logger +from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig -from tests.utils.dataset import get_test_dataset_samples +from tests.utils.dataset import get_common_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE +from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_NAME from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -39,7 +43,17 @@ def get_megatron_test_dataset(prefix: pathlib.Path = MEGATRON_DATASET_PREFIX): and prefix.with_suffix(".bin").is_file() and prefix.parent.joinpath("fast_llm_config.yaml").is_file() ): - MegatronMemmapDataset.write_dataset(prefix, get_test_dataset_samples(vocab_size=MODEL_TEST_VOCAB_SIZE)) + _, _, hf_path = get_common_test_dataset() + hf_dataset = datasets.load_from_disk(hf_path)["train"] + tokenizer = TokenizerConfig(path=TOKENIZER_NAME).get_tokenizer() + samples = [ + LanguageModelSample( + 
TokenSample((tokenizer.tokenize(document["text"]) % MODEL_TEST_VOCAB_SIZE).to(torch.uint16)) + ) + for document in hf_dataset + ] + + MegatronMemmapDataset.write_dataset(prefix, samples) yaml.safe_dump( {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") ) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 7f2c9290a..b21bda1ea 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -1,24 +1,12 @@ import pathlib -import random +import typing +import datasets import numpy as np -import torch -import yaml - -from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter -from fast_llm.data.sample.range import RangeSample -from fast_llm.data.sample.token import TokenSample -from tests.utils.global_variables import ( - DATASET_PATH, - MODEL_DATASET_PATH, - MODEL_TEST_VOCAB_SIZE, - TEST_CHARACTERS, - TEST_DATASET_TOKENS, - TEST_VOCAB_SIZE, - TOKENIZER_FILE, - TOKENIZER_PATH, -) + +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.utils import padded_cumsum +from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH def download_santacoder_tokenizer(): @@ -28,69 +16,165 @@ def download_santacoder_tokenizer(): transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) -def get_random_spans(num_samples: int, max_spans: int, lengths: np.ndarray | int, seed: int = 0): - spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths, [num_samples, max_spans * 2])) - spans = [np.unique(sample_spans).tolist() for sample_spans in spans] +def get_random_spans( + num_documents: int, + max_spans: int, + lengths: np.ndarray | int, + random_state: np.random.RandomState = np.random, + use_last_format: bool = False, + variable_length: bool = True, +): + if variable_length: + spans = random_state.randint( + 0, lengths[:, None] if isinstance(lengths, np.ndarray) else lengths, [num_documents, max_spans * 2] + ) + else: + spans = [ + random_state.choice(range(length), max_spans * 2, replace=False) + for length in (lengths if isinstance(lengths, np.ndarray) else (lengths for _ in range(num_documents))) + ] + spans = [np.unique(sample_spans).tolist() for sample_spans in np.sort(spans)] return [ - [(begin, end) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] + [(begin, end - use_last_format) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] for sample_spans in spans ] -def get_test_dataset_samples( - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - max_spans: int = 0, -) -> list[LanguageModelSample]: - import transformers +def get_random_preference_spans(texts, random_state: np.random.RandomState = np.random) -> dict[str, str]: + texts_ = [] + chosen_spans = [] + rejected_spans = [] + for text in texts: + # Split in three non-empty_chunks + splits = np.sort(random_state.choice(range(1, len(text) - 1), 2, replace=False)).tolist() + texts_.append(text[: splits[0]]) + chosen_spans.append(text[splits[0] : splits[1]]) + rejected_spans.append(text[splits[1] :]) + return {"text": texts_, "chosen_span": chosen_spans, "rejected_span": rejected_spans} - download_santacoder_tokenizer() - texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() - tokenizer = 
transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) +def _get_hf_test_dataset( + seed: int = 1234, + num_documents: int = 1000, + min_document_size: int = 5, + max_document_size: int = 100, + max_loss_masking_spans: int = 0, + has_preference_spans: bool = False, +): + random_state = np.random.RandomState(seed) + # Generate random document sizes (character count). + document_sizes = random_state.randint(min_document_size, max_document_size, num_documents) + size_cumsums = padded_cumsum(document_sizes) + # Generate random ascii characters. + random_text = random_state.randint(32, 127, document_sizes.sum(), dtype=np.uint8).tobytes().decode() + texts = [random_text[begin:end] for begin, end in zip(size_cumsums[:-1], size_cumsums[1:])] + + if has_preference_spans: + dataset_dict = get_random_preference_spans(texts, random_state) + else: + dataset_dict: dict[str, typing.Any] = {"text": texts} + + if max_loss_masking_spans > 0: + dataset_dict["loss_masking_spans"] = get_random_spans( + num_documents, max_loss_masking_spans, document_sizes, random_state, use_last_format=True + ) - samples = [ - LanguageModelSample( - TokenSample(torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size)), + return datasets.Dataset.from_dict(dataset_dict) + + +def _get_test_dataset( + path: pathlib.Path, + seed: int, + tokenizer_path: str = TOKENIZER_PATH, + vocab_size: int | None = None, + documents_per_shard: int = 10**6, + num_documents: int = 1000, + min_document_size: int = 5, + max_document_size: int = 100, + max_loss_masking_spans: int = 0, + has_preference_spans: bool = False, + splits: dict[str, float] | None = None, +): + config_paths = ( + [path / "fast_llm_config.yaml"] + if splits is None + else [path / f"fast_llm_config_{split}.yaml" for split in splits] + ) + hf_path = path / "hf" + + if not (path.is_file() and all(config_path.is_file() for config_path in config_paths)): + dataset = _get_hf_test_dataset( + seed, num_documents, min_document_size, max_document_size, max_loss_masking_spans, has_preference_spans ) - for document in texts - ] - if max_spans > 0: - spans = get_random_spans( - len(samples), max_spans, np.array([[max(len(sample), 1)] for sample in samples]), seed + datasets.DatasetDict({"train": dataset}).save_to_disk(hf_path) + source_schema = {"text": "text"} + if max_loss_masking_spans > 0: + source_schema["loss_masking_spans"] = "loss_masking_spans" + if has_preference_spans: + source_schema["chosen_span"] = "chosen_span" + source_schema["rejected_span"] = "rejected_span" + + download_santacoder_tokenizer() + preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( + { + "dataset": { + "path": hf_path, + "load_from_disk": True, + "source_schema": source_schema, + }, + "tokenizer": {"path": tokenizer_path, "max_vocab_size": vocab_size}, + "output_path": path, + "documents_per_shard": documents_per_shard, + "splits": splits, + } ) - for sample, sample_spans in zip(samples, spans, strict=True): - sample.loss_masking_spans = RangeSample(sample_spans, len(sample)) - return samples + preparator_config.run() + config = ( + {"type": "file", "path": config_paths[0]} + if splits is None + else { + split: {"type": "file", "path": config_path} + for split, config_path in zip(splits, config_paths, strict=True) + } + ) + return path, config, hf_path -def get_test_dataset( - path: pathlib.Path = DATASET_PATH, - seed: int = 1234, - num_tokens: int = TEST_DATASET_TOKENS, - characters: str = TEST_CHARACTERS, - vocab_size: int = TEST_VOCAB_SIZE, - max_spans: 
int = 0, -): - config_path = path.parent.joinpath("fast_llm_config.yaml") - - if not (path.is_file() and config_path.is_file()): - samples = get_test_dataset_samples( - seed=seed, - num_tokens=num_tokens, - characters=characters, - vocab_size=vocab_size, - max_spans=max_spans, - ) - MemmapDataset.write_dataset(path, samples, LanguageModelWriter) - yaml.safe_dump({"type": "memmap", "path": path.name}, config_path.open("w")) +def get_common_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "common_dataset", seed=1234) -def get_model_test_dataset( - path: pathlib.Path = MODEL_DATASET_PATH, - vocab_size: int = MODEL_TEST_VOCAB_SIZE, -): - return get_test_dataset(path, vocab_size=vocab_size) +def get_alt_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "other_dataset", seed=2345) + + +def get_sharded_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "common_dataset_sharded", seed=1234, documents_per_shard=350) + + +def get_split_test_dataset(): + return _get_test_dataset( + DATASET_CACHE / "common_dataset_split", seed=1234, splits={"training": 1, "validation": 1} + ) + + +def get_split_sharded_test_dataset(): + return _get_test_dataset( + DATASET_CACHE / "common_dataset_split_sharded", + seed=1234, + documents_per_shard=350, + splits={"training": 1, "validation": 1}, + ) + + +def get_test_dataset_with_loss_masking_spans(): + return _get_test_dataset(DATASET_CACHE / "dataset_with_loss_masking_spans", seed=1234, max_loss_masking_spans=5) + + +def get_test_dataset_with_preference_spans(): + return _get_test_dataset(DATASET_CACHE / "dataset_with_preference_spans", seed=1234, has_preference_spans=True) + + +def get_model_test_dataset(): + return _get_test_dataset(DATASET_CACHE / "model_dataset", seed=1234, vocab_size=MODEL_TEST_VOCAB_SIZE) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index ea770be0a..20a0c7219 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -5,7 +5,6 @@ import os import pathlib -import string from fast_llm.utils import set_global_variables @@ -36,14 +35,11 @@ def set_testing_global_variables(): # TODO: Fixtures TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" +TOKENIZER_NAME = "bigcode/santacoder" + DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PATH = DATASET_CACHE / "common_dataset/dataset.fast_llm_dataset" -DATASET_WITH_SPANS_PATH = DATASET_CACHE / "dataset_with_spans/dataset.fast_llm_dataset" -DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" -TEST_VOCAB_SIZE = 8192 -# Random lowercase: 80.7% (3.1% each); space: 18.6%; doc end: 0.6% -TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" -TEST_DATASET_TOKENS = 1000000 -MODEL_DATASET_PATH = DATASET_CACHE / "model_dataset/dataset.fast_llm_dataset" +MODEL_DATASET_SHARD_PATH = DATASET_CACHE / "model_dataset/shard_0_0.fast_llm_dataset" + +DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index ee9c2b730..956aaea5a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -22,7 +22,7 @@ Qwen2CheckpointFormat, ) from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PATH, MODEL_TEST_VOCAB_SIZE +from tests.utils.global_variables import MODEL_DATASET_SHARD_PATH, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # 
isort:skip # needed for dynamic type registration EvaluatorsConfig, @@ -234,18 +234,18 @@ def _update_and_add_testing_config( "data": { "datasets": { "training": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "end": 0.969, }, "validation": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "begin": 0.969, "end": 0.999, }, "test": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_SHARD_PATH}, "type": "slice", "begin": 0.999, "end": 1, From 435d21491acb8357a8b49377c7c809e8b8d703d1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 4 Nov 2025 21:32:10 -0500 Subject: [PATCH 06/12] fix --- fast_llm/data/tokenizer.py | 88 -------------------------------------- 1 file changed, 88 deletions(-) delete mode 100644 fast_llm/data/tokenizer.py diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py deleted file mode 100644 index 71219a2bf..000000000 --- a/fast_llm/data/tokenizer.py +++ /dev/null @@ -1,88 +0,0 @@ -import numpy as np -import torch -from transformers import AutoTokenizer - -from fast_llm.data.config import TokenizerConfig -from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.utils import Assert - - -class Tokenizer: - """ - A wrapper around Huggingface (transformers) tokenizer. - """ - - def __init__(self, config: TokenizerConfig): - log_main_rank(f"> loading tokenizer from {config.path} ...") - self.tokenizer = AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=config.path, - errors="replace", - max_len=None, - trust_remote_code=True, - use_fast=True, - ) - if config.bos_token is not None: - self.tokenizer.bos_token = config.bos_token - if self.tokenizer.eos_token_id is None: - raise ValueError("Tokenizer does not have an EOS token.") - if self.tokenizer.bos_token_id is None: - raise ValueError("Tokenizer does not have an BOS token.") - self.eod_id = self.tokenizer.eos_token_id - self.bod_id = self.tokenizer.bos_token_id - - @property - def vocab_size(self) -> int: - return len(self.tokenizer) - - @property - def vocab(self) -> dict[str, int]: - return self.tokenizer.vocab - - @property - def inv_vocab(self) -> dict[int, str]: - return self._inv_vocab - - def tokenize(self, text: str, begin: bool = True, end: bool = True) -> list[int]: - return ( - ([self.bod_id] if begin else []) - + self.tokenizer.encode(text, add_special_tokens=False) - + ([self.eod_id] if end else []) - ) - - def tokenize_with_spans( - self, text: str, begin: bool = True, end: bool = True, *, spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: - """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
- """ - if not spans: - return self.tokenize(text, begin, end), [] - input_ids, token_splits = self.tokenize_with_splits( - text, begin, end, text_splits=[split for splits in spans for split in splits] - ) - return input_ids, [(begin, end) for begin, end in zip(token_splits[::2], token_splits[1::2], strict=True)] - - def tokenize_with_splits( - self, text: str, begin: bool = True, end: bool = True, *, text_splits: list[int] - ) -> tuple[list[int], list[int]]: - Assert.eq(sorted(text_splits), text_splits) - input_ids = [] - text_splits = [0, *text_splits, len(text_splits)] - token_splits = [] - - for split_begin, split_end in zip(text_splits[:-1], text_splits[1:]): - input_ids.extend( - self.tokenize( - text[split_begin:split_end], begin=begin and split_begin == 0, end=end and split_end == len(text) - ) - ) - token_splits.append(len(input_ids)) - - return input_ids, token_splits[:-1] - - def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: - return self.tokenizer.decode(token_ids) - - @property - def eod(self): - return self.eod_id From f6bef55fb25d4c0c85f3bde2763e2ec55baaf416 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 5 Nov 2025 19:22:49 -0500 Subject: [PATCH 07/12] fix --- tests/utils/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index b21bda1ea..ba19916ee 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -102,7 +102,7 @@ def _get_test_dataset( ) hf_path = path / "hf" - if not (path.is_file() and all(config_path.is_file() for config_path in config_paths)): + if not all(config_path.is_file() for config_path in config_paths): dataset = _get_hf_test_dataset( seed, num_documents, min_document_size, max_document_size, max_loss_masking_spans, has_preference_spans ) From e05d9a1d0bc6fbfcdb7229982623653d7f1a7082 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 5 Nov 2025 19:47:25 -0500 Subject: [PATCH 08/12] fix --- fast_llm/data/auto.py | 12 ++++++++++++ fast_llm/models/auto.py | 1 + 2 files changed, 13 insertions(+) diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index c44e538fa..22ab3d731 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -2,4 +2,16 @@ Import these submodules to ensure classes are added to the dynamic class registry. """ +from fast_llm.data.dataset.config import ( # isort: skip + BlendedDatasetConfig, + ConcatenatedDatasetConfig, + DatasetSliceConfig, + MemmapDatasetConfig, + SampledDatasetUpdateConfig, +) +from fast_llm.data.dataset.gpt.config import ( # isort: skip + GPTDatasetFromFileConfig, + GPTFimSampledDatasetConfig, + GPTRandomDatasetConfig, +) from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 322932664..414314627 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,6 +2,7 @@ Import these submodules to ensure classes are added to the dynamic class registry. 
""" +from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip From 9ba8d1bb6aaf7008cf5d5d9ded24ae19b9795233 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 5 Nov 2025 19:59:34 -0500 Subject: [PATCH 09/12] fix --- fast_llm/data/dataset/gpt/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 7583345c3..2334d1173 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -9,12 +9,12 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters from fast_llm.data.preprocessing.tokenizer import TokenizerConfig -from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset + from fast_llm.data.sample.language_model import LanguageModelSample @dataclasses.dataclass(kw_only=True) From b35b297678bba8672e9c81ae52e135ccbd6382eb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 6 Nov 2025 02:40:20 -0500 Subject: [PATCH 10/12] fixes --- fast_llm/data/dataset/gpt/config.py | 1 + tests/data/test_loss_masking_spans.py | 34 +++++++------ tests/data/test_preference_spans.py | 6 ++- tests/data/test_preparator.py | 2 +- tests/utils/dataset.py | 70 ++++++++++++++++++--------- 5 files changed, 70 insertions(+), 43 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 2334d1173..8dd4098a3 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -25,6 +25,7 @@ class GPTSamplingParameters(SamplingParameters): # TODO: Only used for random dataset. Remove? Or use as safety check? 
vocab_size: int | None = None + # TODO: ====== Get these to memmap dataset (currently ignored) ====== use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False diff --git a/tests/data/test_loss_masking_spans.py b/tests/data/test_loss_masking_spans.py index 521eaf2a9..443a26819 100644 --- a/tests/data/test_loss_masking_spans.py +++ b/tests/data/test_loss_masking_spans.py @@ -1,4 +1,5 @@ import datasets +import pytest from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters from fast_llm.data.dataset.memmap import MemmapDataset @@ -10,30 +11,31 @@ from tests.utils.dataset import get_test_dataset_with_loss_masking_spans from tests.utils.global_variables import TOKENIZER_NAME -DATASET_WITH_SPAN_TOKENS = 46199 +DATASET_WITH_SPAN_TOKENS = 45577 DATASET_WITH_SPAN_SAMPLES = { - 27: [49152, 63, 82, 11, 84, 71, 49152], - 30: [49152, 31, 85, 78, 27, 34, 46, 62, 43, 49152], + 27: [49152, 63, 82, 11, 27799, 49152], + 30: [49152, 31, 85, 78, 27, 1448, 62, 43, 49152], 31: [49152, 60, 55, 80, 30, 85, 22, 18, 49152], 77: [49152, 73, 80, 85, 52, 22, 46, 5, 88, 78, 49152], - 87: [49152, 52, 89, 75, 11, 71, 49152], + 87: [49152, 52, 42536, 11, 71, 49152], } HF_LOSS_MASKING_SPANS = { - 27: [[0, 1], [3, 3]], - 30: [[0, 0], [2, 2], [5, 5]], - 31: [[0, 0], [2, 2], [4, 4]], - 77: [[0, 0], [3, 5], [7, 7]], - 87: [[1, 1], [3, 3]], + 27: [[0, 1]], + 30: [[0, 1]], + 31: [[0, 0], [2, 2], [5, 5]], + 77: [[0, 0], [2, 2], [5, 5], [7, 7]], + 87: [[0, 0], [3, 3]], } TOKEN_LOSS_MASKING_SPANS = { - 27: [(1, 3), (4, 5)], - 30: [(1, 2), (3, 4), (6, 7)], - 31: [(1, 2), (3, 4), (5, 6)], - 77: [(1, 2), (4, 7), (8, 9)], - 87: [(2, 3), (4, 5)], + 27: [(1, 3)], + 30: [(1, 3)], + 31: [(1, 2), (3, 4), (6, 7)], + 77: [(1, 2), (3, 4), (6, 7), (8, 9)], + 87: [(1, 2), (3, 4)], } +@pytest.mark.slow def test_gpt_data_with_spans(): _, config, hf_path = get_test_dataset_with_loss_masking_spans() dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() @@ -74,5 +76,5 @@ def test_gpt_data_with_spans(): document = dataset.get_document( index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) ) - Assert.all_equal(document.tokens.tokens, DATASET_WITH_SPAN_SAMPLES[index]) - Assert.all_equal(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) + Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_SPAN_SAMPLES[index]) + Assert.eq(document.loss_masking_spans.ranges, TOKEN_LOSS_MASKING_SPANS[index]) diff --git a/tests/data/test_preference_spans.py b/tests/data/test_preference_spans.py index 7b570c5a1..ef18337eb 100644 --- a/tests/data/test_preference_spans.py +++ b/tests/data/test_preference_spans.py @@ -1,5 +1,6 @@ import datasets import numpy as np +import pytest import torch from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters @@ -36,6 +37,7 @@ } +@pytest.mark.slow def test_gpt_data_with_spans(): _, config, hf_path = get_test_dataset_with_preference_spans() dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() @@ -101,5 +103,5 @@ def test_gpt_data_with_spans(): document = dataset.get_document( index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_loss_masking_spans=True) ) - Assert.all_equal(document.tokens.tokens, DATASET_WITH_PREFERENCE_SPAN_SAMPLES[index]) - Assert.all_equal(document.chosen_spans.ranges + document.rejected_spans.ranges, 
TOKEN_PREFERENCE_SPANS[index]) + Assert.eq(document.tokens.tokens.tolist(), DATASET_WITH_PREFERENCE_SPAN_SAMPLES[index]) + Assert.eq(document.chosen_spans.ranges + document.rejected_spans.ranges, TOKEN_PREFERENCE_SPANS[index]) diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index 235135156..729888d9c 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -72,7 +72,7 @@ def test_common_prepared_dataset(): for index in COMMON_DATASET_SAMPLES: Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) document = dataset.get_document(index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0)) - Assert.all_equal(document.tokens.tokens, COMMON_DATASET_SAMPLES[index]) + Assert.eq(document.tokens.tokens.tolist(), COMMON_DATASET_SAMPLES[index]) @pytest.mark.slow diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index ba19916ee..28d28bd94 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -16,27 +16,45 @@ def download_santacoder_tokenizer(): transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) +def get_random_text( + num_documents: int = 1000, + min_document_size: int = 5, + max_document_size: int = 99, + random_state: np.random.RandomState = np.random, +): + # Randomize document sizes + document_sizes = random_state.randint(min_document_size, max_document_size + 1, num_documents) + size_cumsums = padded_cumsum(document_sizes) + # Generate random ascii characters. + random_text = random_state.randint(32, 127, document_sizes.sum(), dtype=np.uint8).tobytes().decode() + # Gather text by documents. + texts = [ + random_text[size_cumsums[document_index] : size_cumsums[document_index + 1]] + for document_index in range(num_documents) + ] + return texts, document_sizes + + def get_random_spans( - num_documents: int, + document_sizes: np.ndarray, + min_spans: int, max_spans: int, - lengths: np.ndarray | int, random_state: np.random.RandomState = np.random, use_last_format: bool = False, - variable_length: bool = True, ): - if variable_length: - spans = random_state.randint( - 0, lengths[:, None] if isinstance(lengths, np.ndarray) else lengths, [num_documents, max_spans * 2] - ) - else: - spans = [ - random_state.choice(range(length), max_spans * 2, replace=False) - for length in (lengths if isinstance(lengths, np.ndarray) else (lengths for _ in range(num_documents))) - ] - spans = [np.unique(sample_spans).tolist() for sample_spans in np.sort(spans)] + # Randomize span counts. Actual count may be lower for small documents. + span_counts = random_state.randint(min_spans, max_spans + 1, len(document_sizes)) + # Generate random spans. return [ - [(begin, end - use_last_format) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] - for sample_spans in spans + [ + (begin, end - use_last_format) + for begin, end in np.sort( + random_state.choice(range(length), min(num_spans, length // 2) * 2, replace=False) + ) + .reshape([-1, 2]) + .tolist() + ] + for length, num_spans in zip(document_sizes, span_counts, strict=True) ] @@ -57,17 +75,14 @@ def _get_hf_test_dataset( seed: int = 1234, num_documents: int = 1000, min_document_size: int = 5, - max_document_size: int = 100, + max_document_size: int = 99, + min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, ): random_state = np.random.RandomState(seed) # Generate random document sizes (character count). 
- document_sizes = random_state.randint(min_document_size, max_document_size, num_documents) - size_cumsums = padded_cumsum(document_sizes) - # Generate random ascii characters. - random_text = random_state.randint(32, 127, document_sizes.sum(), dtype=np.uint8).tobytes().decode() - texts = [random_text[begin:end] for begin, end in zip(size_cumsums[:-1], size_cumsums[1:])] + texts, document_sizes = get_random_text(num_documents, min_document_size, max_document_size, random_state) if has_preference_spans: dataset_dict = get_random_preference_spans(texts, random_state) @@ -76,7 +91,7 @@ def _get_hf_test_dataset( if max_loss_masking_spans > 0: dataset_dict["loss_masking_spans"] = get_random_spans( - num_documents, max_loss_masking_spans, document_sizes, random_state, use_last_format=True + document_sizes, min_loss_masking_spans, max_loss_masking_spans, random_state, use_last_format=True ) return datasets.Dataset.from_dict(dataset_dict) @@ -90,7 +105,8 @@ def _get_test_dataset( documents_per_shard: int = 10**6, num_documents: int = 1000, min_document_size: int = 5, - max_document_size: int = 100, + max_document_size: int = 99, + min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, splits: dict[str, float] | None = None, @@ -104,7 +120,13 @@ def _get_test_dataset( if not all(config_path.is_file() for config_path in config_paths): dataset = _get_hf_test_dataset( - seed, num_documents, min_document_size, max_document_size, max_loss_masking_spans, has_preference_spans + seed=seed, + num_documents=num_documents, + min_document_size=min_document_size, + max_document_size=max_document_size, + min_loss_masking_spans=min_loss_masking_spans, + max_loss_masking_spans=max_loss_masking_spans, + has_preference_spans=has_preference_spans, ) datasets.DatasetDict({"train": dataset}).save_to_disk(hf_path) source_schema = {"text": "text"} From abe23579fa7bb01181234741d015fb1bd5ed0e54 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 10 Nov 2025 20:04:01 -0500 Subject: [PATCH 11/12] misc --- fast_llm/data/sample/abstract.py | 5 +--- fast_llm/data/sample/range.py | 13 +++++++---- fast_llm/data/sample/token.py | 40 ++++++++++++++++---------------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index aaa321efd..11f5d187c 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -37,11 +37,8 @@ def from_samples(cls, samples: typing.Iterable[Sample]) -> typing.Self: pass @abc.abstractmethod - def to_samples(self) -> list[Sample]: - pass - def crop(self, begin: int, end: int) -> typing.Self: - return self.from_samples(sample.crop(begin, end) for sample in self.to_samples()) + pass def to_device_(self, device: "torch.device | str"): pass diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index c3a035376..b7be4efe1 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -15,6 +15,11 @@ from fast_llm.utils import Assert, get_unique +def crop_ranges(ranges: list[tuple[int, int]], begin: int, end: int) -> list[tuple[int, int]]: + cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, end - begin)) for begin_, end_ in ranges) + return [(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_] + + class RangeSample(Sample): """ A reusable component holding a set of ranges in a sample. 
@@ -36,9 +41,7 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
         return cls(ranges, sample_size)
 
     def crop(self, begin: int, end: int) -> typing.Self:
-        sample_size = end - begin
-        cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, sample_size)) for begin_, end_ in self.ranges)
-        return self.__class__([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size)
+        return self.__class__(crop_ranges(self.ranges, begin, end), end - begin)
 
     def __len__(self) -> int:
         return self.sample_size
@@ -56,8 +59,8 @@ def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int):
     def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self:
         return cls([sample.ranges for sample in samples], get_unique(sample.sample_size for sample in samples))
 
-    def to_samples(self) -> list[RangeSample]:
-        return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges]
+    def crop(self, begin: int, end: int) -> typing.Self:
+        return self.__class__([crop_ranges(sample_ranges, begin, end) for sample_ranges in self.ranges], end - begin)
 
 
 @config_class(dynamic_type={MemmapReaderBaseConfig: "range"})
diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py
index 706b5053a..0944f5689 100644
--- a/fast_llm/data/sample/token.py
+++ b/fast_llm/data/sample/token.py
@@ -16,6 +16,23 @@ from fast_llm.utils import Assert
 
 
+def crop_lengths(lengths: list[int], begin: int, end: int) -> list[int]:
+    if len(lengths) == 1:
+        # Shortcut for the frequent case of a single document.
+        return [end - begin]
+    begin_ = 0
+    cropped_lengths = []
+    for length in lengths:
+        end_ = begin_ + length
+        cropped_length = min(end_, end) - max(begin_, begin)
+        if cropped_length > 0:
+            cropped_lengths.append(cropped_length)
+        if end_ > end:
+            break
+        begin_ = end_
+    return cropped_lengths
+
+
 class TokenSample(Sample):
     def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = None):
         self.tokens = tokens
@@ -34,22 +51,7 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self:
         )
 
     def crop(self, begin: int, end: int) -> typing.Self:
-        sample_size = end - begin
-        if self.lengths == [len(self.tokens)]:
-            # Shortcut for the frequent case of a single document.
- lengths = [sample_size] - else: - begin_ = 0 - lengths = [] - for length in self.lengths: - end_ = begin_ + length - cropped_length = min(end_, end) - max(begin_, begin) - if cropped_length > 0: - lengths.append(cropped_length) - if end_ > end: - break - begin_ = end_ - return self.__class__(self.tokens[begin:end], lengths) + return self.__class__(self.tokens[begin:end], crop_lengths(self.lengths, begin, end)) def __len__(self) -> int: return len(self.tokens) @@ -72,12 +74,10 @@ def from_samples(cls, samples: typing.Iterable[TokenSample]) -> typing.Self: [sample.lengths for sample in samples], ) - def to_samples(self) -> list[TokenSample]: - return [TokenSample(tokens, lengths) for tokens, lengths in zip(self.tokens, self.lengths, strict=True)] - def crop(self, begin: int, end: int) -> typing.Self: return self.__class__( - self.tokens[:, begin:end], [sample.crop(begin, end).lengths for sample in self.to_samples()] + self.tokens[:, begin:end], + [crop_lengths(lengths, begin, end) for lengths in self.lengths], ) def to_device_(self, device: "torch.device | str"): From 1801d873d49c4056927d3d61e5808153c7f6e896 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 10 Nov 2025 20:53:39 -0500 Subject: [PATCH 12/12] fix --- tests/functional/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 3a90745eb..05fafe7a9 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -67,7 +67,7 @@ def test_dpo_loss(): reference_model_logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32) targets = torch.from_numpy(random_state.randint(0, 100, (10, 50))) - spans = get_random_spans(10, 10, 50, random_state) + spans = get_random_spans(np.full(10, 50), 0, 10, random_state) fastllm_loss, fast_llm_grad = compute_dpo_loss( logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1
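As an informal sanity sketch of the cropping semantics introduced above (not part of the patch; the numeric values are illustrative only, and the snippet assumes the patch is applied as-is): both helpers clip document boundaries and spans to the [begin, end) window and re-express them relative to `begin`.

    # Minimal sketch, assuming the crop helpers from this patch series.
    from fast_llm.data.sample.range import crop_ranges
    from fast_llm.data.sample.token import crop_lengths

    # Three packed documents of lengths 4, 3 and 5 (12 tokens total).
    lengths = [4, 3, 5]
    # Cropping to the window [2, 9) keeps 2 tokens of doc 0, all of doc 1 and 2 tokens of doc 2.
    assert crop_lengths(lengths, 2, 9) == [2, 3, 2]

    # Span bounds are shifted by `begin`, clipped to the window, and empty spans are dropped.
    assert crop_ranges([(0, 3), (5, 8), (10, 12)], 2, 9) == [(0, 1), (3, 6)]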