"""
Transform HeteroData to Graph Transformer sequence input.
This module provides functionality to convert PyG HeteroData objects (typically
batched 2-hop subgraphs) into sequence format suitable for Graph Transformers.
For each anchor node in the batch, the transform extracts its k-hop neighborhood
and creates a fixed-length sequence of node features with padding.
Example Usage:
>>> from torch_geometric.data import HeteroData
>>> from gigl.transforms.graph_transformer import heterodata_to_graph_transformer_input
>>>
>>> # Create batched HeteroData (e.g., from NeighborLoader)
>>> # First batch_size nodes in each node type are anchor nodes
>>> data = HeteroData()
>>> data['user'].x = torch.randn(100, 64) # 100 users, first N are anchors
>>> data['item'].x = torch.randn(50, 32)
>>> data['user', 'buys', 'item'].edge_index = ...
>>>
>>> # Transform to Graph Transformer input
>>> sequences, valid_mask, attention_bias_data = heterodata_to_graph_transformer_input(
... data=data,
... batch_size=32,
... max_seq_len=128,
... anchor_node_type='user',
... )
>>> # sequences: (batch_size, max_seq_len, feature_dim)
>>> # valid_mask: (batch_size, max_seq_len)
With Relative Encodings:
Relative encodings stored as sparse graph-level attributes can be returned as
raw attention-bias features:
>>> from torch_geometric.transforms import Compose
>>> from gigl.transforms.add_positional_encodings import (
... AddHeteroRandomWalkEncodings,
... AddHeteroHopDistanceEncoding,
... )
>>>
>>> # First apply PE transforms to the data
>>> pe_transform = Compose([
... AddHeteroRandomWalkEncodings(walk_length=8),
... AddHeteroHopDistanceEncoding(h_max=5),
... ])
>>> data = pe_transform(data)
>>>
>>> # Transform to sequences with relative bias features
>>> sequences, valid_mask, attention_bias_data = heterodata_to_graph_transformer_input(
... data=data,
... batch_size=32,
... max_seq_len=128,
... anchor_node_type='user',
... anchor_based_attention_bias_attr_names=['hop_distance'],
... )
>>> # sequences: (batch_size, max_seq_len, feature_dim)
>>> # attention_bias_data['anchor_bias']: (batch_size, max_seq_len, 1)
"""
from typing import Literal, Optional, TypedDict
import torch
from torch import Tensor
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import NodeType
from torch_geometric.utils import to_torch_sparse_tensor
class SequenceAuxiliaryData(TypedDict):
    """Auxiliary attention-bias tensors returned alongside token sequences.

    Both entries may be ``None`` when the corresponding bias kind was not
    requested.
    """

    # (batch_size, max_seq_len, num_attrs) anchor-relative bias features.
    anchor_bias: Optional[Tensor]
    # (batch_size, max_seq_len, max_seq_len, num_attrs) pairwise bias features.
    pairwise_bias: Optional[Tensor]


# Special feature name used to request raw PPR edge weights as an
# anchor-relative bias channel (see _compose_anchor_feature_* helpers).
PPR_WEIGHT_FEATURE_NAME = "ppr_weight"
def _get_node_type_offsets(
data: HeteroData,
node_type_order: list[NodeType],
) -> dict[NodeType, int]:
offsets: dict[NodeType, int] = {}
offset = 0
for node_type in node_type_order:
offsets[node_type] = offset
offset += data[node_type].num_nodes
return offsets
def _validate_ppr_sequence_input(data: HeteroData) -> None:
if not data.edge_types:
raise ValueError(
"sequence_construction_method='ppr' requires at least one PPR edge type."
)
if any(edge_type[1] != "ppr" for edge_type in data.edge_types):
raise ValueError(
"sequence_construction_method='ppr' expects the hetero batch to contain "
f"only PPR edges, got edge types: {data.edge_types}."
)
for edge_type in data.edge_types:
edge_store = data[edge_type]
if not hasattr(edge_store, "edge_attr") or edge_store.edge_attr is None:
raise ValueError(
"sequence_construction_method='ppr' requires every PPR edge type to "
f"have edge_attr weights, but {edge_type} is missing them."
)
def _get_sparse_feature_matrices(
data: HeteroData,
attr_names: Optional[list[str]],
missing_attr_error_prefix: str,
) -> list[Tensor]:
matrices: list[Tensor] = []
for attr_name in attr_names or []:
if not hasattr(data, attr_name):
raise ValueError(
f"{missing_attr_error_prefix} '{attr_name}' not found in data. "
"Make sure to apply the corresponding transform first."
)
matrices.append(getattr(data, attr_name))
return matrices
def _compose_anchor_feature_tensor(
    anchor_relative_feature_sequences: Optional[Tensor],
    available_anchor_attr_names: list[str],
    requested_anchor_attr_names: list[str],
    ppr_weight_sequences: Optional[Tensor],
) -> Optional[Tensor]:
    """Concatenate the requested anchor-relative feature channels (last dim).

    The special name ``PPR_WEIGHT_FEATURE_NAME`` selects ``ppr_weight_sequences``;
    every other name selects its channel of ``anchor_relative_feature_sequences``.
    Returns None when nothing was requested.

    Raises:
        ValueError: if a requested feature was not computed or is unknown.
    """
    if not requested_anchor_attr_names:
        return None
    channel_by_name = {
        name: channel for channel, name in enumerate(available_anchor_attr_names)
    }
    parts: list[Tensor] = []
    for name in requested_anchor_attr_names:
        if name == PPR_WEIGHT_FEATURE_NAME:
            if ppr_weight_sequences is None:
                raise ValueError(
                    f"Requested '{PPR_WEIGHT_FEATURE_NAME}' but it was not computed."
                )
            parts.append(ppr_weight_sequences)
        else:
            if anchor_relative_feature_sequences is None:
                raise ValueError(
                    "Anchor-relative features were requested but not computed."
                )
            if name not in channel_by_name:
                raise ValueError(
                    f"Anchor-relative feature '{name}' was requested but not found."
                )
            channel = channel_by_name[name]
            parts.append(
                anchor_relative_feature_sequences[..., channel : channel + 1]
            )
    return torch.cat(parts, dim=-1)
def _compose_anchor_feature_dict(
    anchor_relative_feature_sequences: Optional[Tensor],
    available_anchor_attr_names: list[str],
    requested_anchor_attr_names: list[str],
    ppr_weight_sequences: Optional[Tensor],
) -> Optional["TokenInputData"]:
    """Build a name -> feature-tensor dict for the requested anchor features.

    Same selection rules as ``_compose_anchor_feature_tensor``, but the
    per-feature (batch, seq, 1) tensors are returned keyed by name instead of
    concatenated.

    Args:
        anchor_relative_feature_sequences: (batch, seq, num_available) tensor
            of anchor-relative features, or None if none were computed.
        available_anchor_attr_names: channel names of
            ``anchor_relative_feature_sequences``, in channel order.
        requested_anchor_attr_names: feature names to include; the special name
            ``PPR_WEIGHT_FEATURE_NAME`` selects ``ppr_weight_sequences``.
        ppr_weight_sequences: (batch, seq, 1) PPR weight tensor, or None.

    Returns:
        Dict mapping each requested name to its (batch, seq, 1) tensor, or
        None when nothing was requested.

    Raises:
        ValueError: if a requested feature was not computed or is unknown.
    """
    if not requested_anchor_attr_names:
        return None
    # NOTE(review): ``TokenInputData`` is not defined or imported in this
    # module, and there is no ``from __future__ import annotations`` here, so
    # an unquoted return annotation would be evaluated at import time and
    # raise NameError. The annotations are quoted (lazy forward references) to
    # keep the module importable — confirm where the alias is meant to live.
    feature_dict: "TokenInputData" = {}
    feature_index_by_name = {
        attr_name: idx for idx, attr_name in enumerate(available_anchor_attr_names)
    }
    for attr_name in requested_anchor_attr_names:
        if attr_name == PPR_WEIGHT_FEATURE_NAME:
            if ppr_weight_sequences is None:
                raise ValueError(
                    f"Requested '{PPR_WEIGHT_FEATURE_NAME}' but it was not computed."
                )
            feature_dict[attr_name] = ppr_weight_sequences
            continue
        if anchor_relative_feature_sequences is None:
            raise ValueError(
                "Anchor-relative features were requested but not computed."
            )
        if attr_name not in feature_index_by_name:
            raise ValueError(
                f"Anchor-relative feature '{attr_name}' was requested but not found."
            )
        feature_idx = feature_index_by_name[attr_name]
        # Slice with a 1-wide range so the channel dimension is kept.
        feature_dict[attr_name] = anchor_relative_feature_sequences[
            ..., feature_idx : feature_idx + 1
        ]
    return feature_dict
def _build_sequence_layout_from_sparse_neighbors(
reachable: Tensor,
anchor_indices: Tensor,
max_seq_len: int,
include_anchor_first: bool,
device: torch.device,
) -> tuple[Tensor, Tensor]:
"""
Build padded node-index sequences and a validity mask from reachability data.
Returns:
node_index_sequences: (batch_size, max_seq_len) long tensor with -1 padding
valid_mask: (batch_size, max_seq_len) bool tensor
"""
batch_size = anchor_indices.size(0)
node_index_sequences = torch.full(
(batch_size, max_seq_len),
fill_value=-1,
dtype=torch.long,
device=device,
)
valid_mask = torch.zeros(
(batch_size, max_seq_len),
dtype=torch.bool,
device=device,
)
indices = reachable.indices()
batch_idx = indices[0]
node_idx = indices[1]
if include_anchor_first and max_seq_len > 0:
node_index_sequences[:, 0] = anchor_indices
valid_mask[:, 0] = True
if batch_idx.numel() > 0:
keep = node_idx != anchor_indices[batch_idx]
batch_idx = batch_idx[keep]
node_idx = node_idx[keep]
start_pos = 1
else:
start_pos = 0
if batch_idx.numel() == 0 or start_pos >= max_seq_len:
return node_index_sequences, valid_mask
# Compute within-anchor sequence positions for each neighbor node.
#
# After extracting from the sparse tensor, we have flattened arrays:
# batch_idx = [0, 0, 0, 1, 1, 2, 2, 2, 2] <- which anchor each node belongs to
# node_idx = [5, 7, 9, 3, 8, 1, 4, 6, 10] <- the reachable neighbor node IDs
#
# We need to compute sequence positions (1, 2, 3, ...) for each anchor's neighbors:
# positions = [1, 2, 3, 1, 2, 1, 2, 3, 4]
# ^anchor0^ ^a1^ ^--anchor2--^
#
# This allows scattering into the 2D output:
# node_index_sequences[anchor_idx, position] = node_idx
n = batch_idx.size(0)
# Step 1: Mark where each anchor's neighbor group starts
# is_group_start = [1, 0, 0, 1, 0, 1, 0, 0, 0]
# ^ ^ ^
# anchor0 anchor1 anchor2 starts here
is_group_start = torch.zeros(n, dtype=torch.long, device=device)
is_group_start[0] = 1
if n > 1:
is_group_start[1:] = (batch_idx[1:] != batch_idx[:-1]).long()
# Step 2: Assign each node to its anchor group (0-indexed)
# group_id = [0, 0, 0, 1, 1, 2, 2, 2, 2]
group_id = is_group_start.cumsum(0) - 1
# Step 3: Find the starting index of each group in the flattened array
# group_starts = [0, 3, 5] (indices where each anchor's neighbors begin)
group_starts = torch.nonzero(is_group_start, as_tuple=True)[0]
# Step 4: Compute position = (global_index - group_start) + start_pos
# This gives within-group position offset by start_pos (usually 1 for anchor at pos 0)
# positions = [1, 2, 3, 1, 2, 1, 2, 3, 4]
positions = torch.arange(n, device=device) - group_starts[group_id] + start_pos
# Step 5: Filter out positions that exceed max_seq_len (truncation)
valid = positions < max_seq_len
valid_batch_idx = batch_idx[valid]
valid_positions = positions[valid]
valid_node_idx = node_idx[valid]
# Step 6: Scatter valid nodes into the output tensors
node_index_sequences[valid_batch_idx, valid_positions] = valid_node_idx
valid_mask[valid_batch_idx, valid_positions] = True
return node_index_sequences, valid_mask
def _build_sequence_layout_from_ppr_edges(
    homo_data: Data,
    anchor_indices: Tensor,
    max_seq_len: int,
    include_anchor_first: bool,
    num_nodes: int,
    device: torch.device,
    return_edge_weights: bool = False,
) -> tuple[Tensor, Tensor, Optional[Tensor]]:
    """Build sequences directly from outgoing PPR edges for each anchor.

    The sequence order is:

    1. Anchor node first, when ``include_anchor_first`` is True.
    2. Destination nodes reachable by outgoing ``"ppr"`` edges from that anchor,
       sorted by descending PPR weight.

    Args:
        homo_data: homogeneous graph whose ``edge_index`` holds the PPR edges
            and whose ``edge_attr`` holds the matching weights (1D or [N, 1]).
        anchor_indices: (batch_size,) anchor node indices in the homogeneous graph.
        max_seq_len: fixed output sequence length; extra neighbors are truncated.
        include_anchor_first: when True, slot 0 of every sequence is the anchor.
        num_nodes: total number of nodes in the homogeneous graph.
        device: device for the output tensors.
        return_edge_weights: when True, also return per-slot PPR weights.

    Returns:
        node_index_sequences: (batch_size, max_seq_len) long tensor, -1 padded.
        valid_mask: (batch_size, max_seq_len) bool tensor.
        ppr_weight_sequences: (batch_size, max_seq_len, 1) float tensor of PPR
            weights (0.0 at padding and at the anchor slot), or None when
            ``return_edge_weights`` is False.

    Raises:
        ValueError: if ``edge_attr`` is missing or has an unsupported shape.
    """
    batch_size = anchor_indices.size(0)
    node_index_sequences = torch.full(
        (batch_size, max_seq_len),
        fill_value=-1,
        dtype=torch.long,
        device=device,
    )
    valid_mask = torch.zeros(
        (batch_size, max_seq_len),
        dtype=torch.bool,
        device=device,
    )
    ppr_weight_sequences = None
    if return_edge_weights:
        ppr_weight_sequences = torch.zeros(
            (batch_size, max_seq_len, 1),
            dtype=torch.float,
            device=device,
        )
    # Reserve slot 0 for the anchor itself when requested.
    if include_anchor_first and max_seq_len > 0:
        node_index_sequences[:, 0] = anchor_indices
        valid_mask[:, 0] = True
        start_pos = 1
    else:
        start_pos = 0
    if start_pos >= max_seq_len:
        # No room left for any neighbor (e.g. max_seq_len == 1 with anchor-first).
        return node_index_sequences, valid_mask, ppr_weight_sequences
    if not hasattr(homo_data, "edge_attr") or homo_data.edge_attr is None:
        raise ValueError(
            "sequence_construction_method='ppr' requires homogeneous edge_attr weights."
        )
    # Normalize weights to 1D; [N, 1] is the only accepted 2D layout.
    edge_weights = homo_data.edge_attr
    if edge_weights.dim() == 2:
        if edge_weights.size(1) != 1:
            raise ValueError(
                "PPR edge weights must be 1D or shape [N, 1], "
                f"got {tuple(edge_weights.shape)}."
            )
        edge_weights = edge_weights.squeeze(1)
    elif edge_weights.dim() != 1:
        raise ValueError(
            "PPR edge weights must be 1D or shape [N, 1], "
            f"got {tuple(edge_weights.shape)}."
        )
    # Map homogeneous node index -> that anchor's row in the batch
    # (-1 for nodes that are not anchors), so edges can be bucketed per anchor.
    anchor_batch_index_by_homo_idx = torch.full(
        (num_nodes,),
        fill_value=-1,
        dtype=torch.long,
        device=device,
    )
    anchor_batch_index_by_homo_idx[anchor_indices] = torch.arange(
        batch_size, device=device
    )
    src_idx = homo_data.edge_index[0]
    dst_idx = homo_data.edge_index[1]
    anchor_batch_idx = anchor_batch_index_by_homo_idx[src_idx]
    # Keep only edges whose source is one of this batch's anchors.
    keep = anchor_batch_idx >= 0
    if not keep.any():
        return node_index_sequences, valid_mask, ppr_weight_sequences
    all_anchor_batch_idx = anchor_batch_idx[keep]
    all_dst_idx = dst_idx[keep]
    all_weights = edge_weights[keep]
    if include_anchor_first:
        # The anchor already occupies slot 0; drop edges pointing back to it.
        keep = all_dst_idx != anchor_indices[all_anchor_batch_idx]
        if not keep.any():
            return node_index_sequences, valid_mask, ppr_weight_sequences
        all_anchor_batch_idx = all_anchor_batch_idx[keep]
        all_dst_idx = all_dst_idx[keep]
        all_weights = all_weights[keep]
    # Flattened COO edges can be laid out in one pass by sorting first on weight
    # and then stably on anchor batch id, which preserves descending-weight order
    # within each anchor group without a Python loop.
    weight_order = torch.argsort(all_weights, descending=True, stable=True)
    all_anchor_batch_idx = all_anchor_batch_idx[weight_order]
    all_dst_idx = all_dst_idx[weight_order]
    all_weights = all_weights[weight_order]
    batch_order = torch.argsort(all_anchor_batch_idx, stable=True)
    sorted_batch_idx = all_anchor_batch_idx[batch_order]
    sorted_dst_idx = all_dst_idx[batch_order]
    sorted_weights = all_weights[batch_order]
    # Rank each entry within its anchor group (same cumsum trick as
    # _build_sequence_layout_from_sparse_neighbors) to get its sequence slot.
    n = sorted_batch_idx.size(0)
    is_group_start = torch.zeros(n, dtype=torch.long, device=device)
    is_group_start[0] = 1
    if n > 1:
        is_group_start[1:] = (sorted_batch_idx[1:] != sorted_batch_idx[:-1]).long()
    group_id = is_group_start.cumsum(0) - 1
    group_starts = torch.nonzero(is_group_start, as_tuple=True)[0]
    positions = torch.arange(n, device=device) - group_starts[group_id] + start_pos
    # Truncate entries that would spill past the fixed sequence length.
    valid = positions < max_seq_len
    valid_batch_idx = sorted_batch_idx[valid]
    valid_positions = positions[valid]
    valid_dst_idx = sorted_dst_idx[valid]
    valid_weights = sorted_weights[valid]
    node_index_sequences[valid_batch_idx, valid_positions] = valid_dst_idx
    valid_mask[valid_batch_idx, valid_positions] = True
    if ppr_weight_sequences is not None:
        ppr_weight_sequences[
            valid_batch_idx, valid_positions, 0
        ] = valid_weights.float()
    return node_index_sequences, valid_mask, ppr_weight_sequences
def _gather_sequences_from_node_indices(
node_index_sequences: Tensor,
node_features: Tensor,
valid_mask: Tensor,
padding_value: float,
) -> Tensor:
"""Gather node features into padded sequences using precomputed node indices.
Args:
node_index_sequences: (batch_size, max_seq_len) node indices
node_features: (num_nodes, feature_dim) node features
valid_mask: (batch_size, max_seq_len) bool tensor indicating valid positions
padding_value: Value to use for padding
Returns:
(batch_size, max_seq_len, feature_dim) padded sequences.
"""
batch_size, max_seq_len = node_index_sequences.shape
feature_dim = node_features.size(-1)
sequences = torch.full(
(batch_size, max_seq_len, feature_dim),
padding_value,
dtype=node_features.dtype,
device=node_features.device,
)
if feature_dim == 0 or not valid_mask.any():
return sequences
sequences[valid_mask] = node_features[node_index_sequences[valid_mask]]
return sequences
def _lookup_anchor_relative_features(
    anchor_indices: Tensor,
    node_index_sequences: Tensor,
    valid_mask: Tensor,
    csr_matrices: Optional[list[Tensor]],
    device: torch.device,
) -> Optional[Tensor]:
    """
    Gather anchor-to-token sparse values for every valid sequence position.

    For each matrix ``k`` and valid position ``(b, i)``, the output holds
    ``csr_matrices[k][anchor_indices[b], node_index_sequences[b, i]]`` — i.e.
    the relationship between a token and its own anchor (e.g. hop distance).
    Padding positions are left at 0.0.

    Args:
        anchor_indices: (batch_size,) anchor node indices in the homogeneous graph.
        node_index_sequences: (batch_size, max_seq_len) node index per position.
        valid_mask: (batch_size, max_seq_len) bool mask of real (non-pad) tokens.
        csr_matrices: sparse CSR matrices, each (num_nodes, num_nodes).
        device: device for the output tensor.

    Returns:
        (batch_size, max_seq_len, num_attrs) float tensor, or None when no
        matrices were provided.
    """
    if not csr_matrices:
        return None
    batch_size, max_seq_len = node_index_sequences.shape
    features = torch.zeros(
        (batch_size, max_seq_len, len(csr_matrices)),
        dtype=torch.float,
        device=device,
    )
    if not valid_mask.any():
        return features
    # Flatten all valid (batch, position) entries once; every matrix reuses them.
    entry_batch, entry_pos = torch.nonzero(valid_mask, as_tuple=True)
    entry_nodes = node_index_sequences[valid_mask]
    entry_anchors = anchor_indices[entry_batch]
    for channel, pe_matrix in enumerate(csr_matrices):
        features[entry_batch, entry_pos, channel] = _lookup_csr_values(
            csr_matrix=pe_matrix,
            row_indices=entry_anchors,
            col_indices=entry_nodes,
        )
    return features
def _lookup_pairwise_relative_features(
    node_index_sequences: Tensor,
    valid_mask: Tensor,
    csr_matrices: Optional[list[Tensor]],
    device: torch.device,
) -> Optional[Tensor]:
    """
    Gather token-to-token sparse values for every valid pair of positions.

    For each matrix ``k`` and pair of valid positions ``(b, i)`` and ``(b, j)``,
    the output holds
    ``csr_matrices[k][node_index_sequences[b, i], node_index_sequences[b, j]]``,
    capturing pairwise relationships between sequence tokens (e.g. random-walk
    structural encodings). Pairs touching a padding position are left at 0.0.
    The result is typically added to attention scores before softmax.

    Args:
        node_index_sequences: (batch_size, max_seq_len) node index per position.
        valid_mask: (batch_size, max_seq_len) bool mask of real (non-pad) tokens.
        csr_matrices: sparse CSR matrices, each (num_nodes, num_nodes).
        device: device for the output tensor.

    Returns:
        (batch_size, max_seq_len, max_seq_len, num_attrs) float tensor, or None
        when no matrices were provided.
    """
    if not csr_matrices:
        return None
    batch_size, max_seq_len = node_index_sequences.shape
    features = torch.zeros(
        (batch_size, max_seq_len, max_seq_len, len(csr_matrices)),
        dtype=torch.float,
        device=device,
    )
    # A pair is valid only when both of its endpoints are valid tokens.
    pair_valid = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
    if not pair_valid.any():
        return features
    # Enumerate the valid (b, i, j) triples once; every matrix reuses them.
    pair_b, pair_i, pair_j = torch.nonzero(pair_valid, as_tuple=True)
    source_nodes = node_index_sequences[pair_b, pair_i]
    target_nodes = node_index_sequences[pair_b, pair_j]
    for channel, pe_matrix in enumerate(csr_matrices):
        features[pair_b, pair_i, pair_j, channel] = _lookup_csr_values(
            csr_matrix=pe_matrix,
            row_indices=source_nodes,
            col_indices=target_nodes,
        )
    return features
def _get_k_hop_neighbors_sparse(
    anchor_indices: Tensor,
    edge_index: Tensor,
    num_nodes: int,
    k: int,
    device: torch.device,
) -> Tensor:
    """
    Find k-hop reachable nodes for all anchors via sparse matrix powers.

    Repeatedly multiplies a one-hot anchor matrix by the (binarized) adjacency
    matrix and accumulates the result — no per-node membership checks needed.

    Args:
        anchor_indices: (batch_size,) anchor node indices.
        edge_index: (2, num_edges) edge index.
        num_nodes: total number of nodes.
        k: number of hops.
        device: device for the tensors.

    Returns:
        (batch_size, num_nodes) sparse COO tensor; an entry is non-zero iff the
        node is reachable from that anchor within k hops (anchors reach themselves).
    """
    batch_size = anchor_indices.size(0)
    # Coalesce once so duplicate edges collapse, then rebuild with unit weights.
    coalesced_adj = to_torch_sparse_tensor(
        edge_index, size=(num_nodes, num_nodes)
    ).coalesce()
    adj_coords = coalesced_adj.indices()
    adj = torch.sparse_coo_tensor(
        adj_coords,
        torch.ones(adj_coords.size(1), device=device, dtype=torch.float),
        size=(num_nodes, num_nodes),
    )  # Already-unique indices, so no further coalesce is required.
    # Seed matrix: row b holds a single 1 at column anchor_indices[b].
    seed_coords = torch.stack(
        [
            torch.arange(batch_size, device=device),
            anchor_indices,
        ]
    )
    reachable = torch.sparse_coo_tensor(
        seed_coords,
        torch.ones(batch_size, device=device, dtype=torch.float),
        size=(batch_size, num_nodes),
    )  # One entry per row — indices are unique by construction.
    frontier = reachable
    for _ in range(k):
        # One more hop: rows of `frontier` spread along outgoing edges.
        frontier = torch.sparse.mm(frontier, adj)
        if frontier._nnz() == 0:
            break
        reachable = reachable + frontier
    # Merge duplicate coordinates, then binarize the accumulated counts.
    accumulated = reachable.coalesce()
    return torch.sparse_coo_tensor(
        accumulated.indices(),
        torch.ones(accumulated._nnz(), device=device, dtype=torch.float),
        size=(batch_size, num_nodes),
    ).coalesce()
def _lookup_csr_values(
csr_matrix: Tensor,
row_indices: Tensor,
col_indices: Tensor,
default_value: float = 0.0,
) -> Tensor:
"""
Look up values in a CSR sparse matrix for given (row, col) pairs.
Vectorized CSR lookup: for each query, slice the row and search for the column.
Time complexity: O(n * avg_nnz_per_row), typically O(n) for sparse matrices.
Args:
csr_matrix: (num_rows, num_cols) sparse CSR tensor
row_indices: (n,) row indices to look up
col_indices: (n,) column indices to look up
default_value: Value for missing entries (default: 0.0)
Returns:
(n,) values from csr_matrix[row, col], or default_value if not present
"""
n = row_indices.size(0)
device = row_indices.device
if n == 0:
return torch.zeros(0, device=device, dtype=torch.float)
crow_indices = csr_matrix.crow_indices()
col_indices_csr = csr_matrix.col_indices()
values_csr = csr_matrix.values()
# Get row start/end pointers
row_starts = crow_indices[row_indices]
row_ends = crow_indices[row_indices + 1]
row_lengths = row_ends - row_starts
max_row_len = row_lengths.max().item()
if max_row_len == 0:
return torch.full((n,), default_value, device=device, dtype=torch.float)
# Build offset matrix: (n, max_row_len)
offsets = row_starts.unsqueeze(1) + torch.arange(max_row_len, device=device)
valid_mask = offsets < row_ends.unsqueeze(1)
# Safe indexing with clamping
nnz = col_indices_csr.size(0)
offsets_clamped = offsets.clamp(max=max(nnz - 1, 0))
# Get columns at offsets and find matches
cols_at_offsets = col_indices_csr[offsets_clamped]
col_matches = (cols_at_offsets == col_indices.unsqueeze(1)) & valid_mask
# Find which queries have matches
found = col_matches.any(dim=1)
# Initialize output
result = torch.full((n,), default_value, device=device, dtype=torch.float)
if found.any():
# Get match positions and retrieve values
match_offsets = col_matches.float().argmax(dim=1)
value_indices = row_starts[found] + match_offsets[found]
result[found] = values_csr[value_indices].float()
return result