# Source code for gigl.distributed.dist_link_prediction_dataset
from multiprocessing.reduction import ForkingPickler
from gigl.common.logger import Logger
from gigl.distributed.dist_dataset import DistDataset, _reduce_distributed_dataset
# TODO (mkolodner-sc): Deprecate this class in favor of DistDataset
# [docs]
# Module-level logger used for the deprecation notice emitted in __init__.
# NOTE(review): assumes Logger can be constructed with no arguments — confirm
# against gigl.common.logger.
logger = Logger()


class DistLinkPredictionDataset(DistDataset):
    """Deprecated subclass of ``DistDataset`` kept under the old name.

    The class adds no behavior of its own beyond logging a deprecation
    warning each time an instance is constructed; everything else is
    inherited from ``DistDataset``.
    """

    def __init__(self, *args, **kwargs):
        """Log a deprecation warning, then defer to ``DistDataset.__init__``.

        Args:
            *args: Positional arguments forwarded unchanged to the parent.
            **kwargs: Keyword arguments forwarded unchanged to the parent.
        """
        logger.warning(
            "DistLinkPredictionDataset is deprecated. Please use DistDataset instead."
        )
        super().__init__(*args, **kwargs)
# Hook DistLinkPredictionDataset into multiprocessing's ForkingPickler so that
# instances are reduced with the same function used for DistDataset.
# Per the original note, this lets dataset objects cross process boundaries via
# IPC handles rather than by pickling the underlying shared memory directly,
# which would fail or corrupt data.
ForkingPickler.register(
    type=DistLinkPredictionDataset,
    reduce=_reduce_distributed_dataset,
)