Code: Select all
import json
from torch.utils.data import Dataset, ConcatDataset, DataLoader
class JSONLDataset(Dataset):
    """Map-style dataset giving random access to the records of a JSON Lines file.

    The file is scanned once at construction time to record the byte offset at
    which every non-blank line starts. Records are then fetched on demand by
    seeking to the stored offset, so the file contents are never held in
    memory as a whole.

    Args:
        file_path: Path of the ``.jsonl`` file; each non-blank line must be a
            UTF-8 encoded JSON document.
        max_seq_len: Stored on the instance as ``self.max_seq_len``. This class
            does not use the value itself. Defaults to ``None``.
    """

    def __init__(self, file_path, max_seq_len=None):
        super().__init__()
        self.file_path = file_path
        self.max_seq_len = max_seq_len
        # Opened lazily by __getitem__, so constructing many datasets does not
        # hold one file descriptor each.
        self.file = None
        self.offsets = []

        # Index in binary mode: len(line) is then the exact number of bytes on
        # disk (no newline translation, no locale-dependent decoding), which is
        # what seek() needs later.
        position = 0
        with open(file_path, 'rb') as f:
            for raw_line in f:
                # Blank lines carry no record; indexing them would make
                # json.loads fail in __getitem__.
                if raw_line.strip():
                    self.offsets.append(position)
                position += len(raw_line)

    def __len__(self):
        """Return the number of indexed (non-blank) lines."""
        return len(self.offsets)

    def __getitem__(self, index):
        """Return the JSON-decoded record stored on the ``index``-th indexed line.

        Raises:
            IndexError: If ``index`` is outside the range of indexed lines.
            json.JSONDecodeError: If the line is not valid JSON.
        """
        if self.file is None:
            # Binary mode keeps seek() positions identical to the byte offsets
            # recorded in __init__.
            self.file = open(self.file_path, 'rb')
        self.file.seek(self.offsets[index])
        raw_line = self.file.readline()
        return json.loads(raw_line.decode('utf-8'))

    def __getstate__(self):
        """Return the instance state without the open file handle.

        File objects cannot be pickled; a copy reopens the file on its first
        ``__getitem__`` call instead.
        """
        state = self.__dict__.copy()
        state['file'] = None
        return state

    def close(self):
        """Close the underlying file handle if one is open."""
        if self.file is not None:
            self.file.close()
            self.file = None
# Value handed to every shard as max_seq_len. 512 is a placeholder, not derived
# from the data -- TODO confirm against the consuming model.
MAX_SEQ_LEN = 512

# JSONLDataset.__init__ takes max_seq_len as its second argument, so pass it
# explicitly for each of the 50 shard files.
datasets = [JSONLDataset(f"./data/file_{i}.jsonl", MAX_SEQ_LEN) for i in range(50)]
# Expose the shards as a single dataset and draw shuffled batches of 8 from it.
big_dataset = ConcatDataset(datasets)
dataloader = DataLoader(big_dataset, batch_size=8, shuffle=True)
Mobile version