dicee.trainer.torch_trainer_fsdp

Multi-GPU trainer: row-wise sharded entity embeddings + FSDP dense parameters.

Entity embeddings are sharded row-wise across ranks, with lookups routed through dist.all_to_all_single. Each rank owns a contiguous slice of entity rows; the forward and backward passes use all_to_all to route index requests to the owning rank and to return gradients to it. A per-rank _LocalSparseAdam updates only the rows that received a non-zero gradient in each batch.
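The routing described above can be illustrated with a small single-process simulation. This is only a sketch under stated assumptions, not dicee's implementation: ranks are modelled as plain Python lists, the hypothetical helper all_to_all stands in for dist.all_to_all_single (which operates on flat tensors with split sizes), and the names shard_bounds, owner_of and sharded_lookup are invented for illustration.

```python
from typing import List, Tuple

Row = List[float]


def shard_bounds(num_entities: int, world_size: int) -> List[Tuple[int, int]]:
    """Contiguous [start, end) row ranges, one per rank (assumed even split)."""
    base, extra = divmod(num_entities, world_size)
    bounds, start = [], 0
    for rank in range(world_size):
        end = start + base + (1 if rank < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def owner_of(index: int, bounds: List[Tuple[int, int]]) -> int:
    """Rank whose slice contains the given global entity index."""
    for rank, (start, end) in enumerate(bounds):
        if start <= index < end:
            return rank
    raise IndexError(index)


def all_to_all(send: list) -> list:
    """Simulated exchange: send[src][dst] arrives as recv[dst][src]."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]


def sharded_lookup(batches: List[List[int]], shards: List[List[Row]],
                   bounds: List[Tuple[int, int]]) -> List[List[Row]]:
    world = len(shards)
    # 1. Each rank buckets its requested indices by owning rank and
    #    remembers where each request sat in the original batch.
    send_idx = [[[] for _ in range(world)] for _ in range(world)]
    positions = [[[] for _ in range(world)] for _ in range(world)]
    for src, batch in enumerate(batches):
        for pos, idx in enumerate(batch):
            dst = owner_of(idx, bounds)
            send_idx[src][dst].append(idx)
            positions[src][dst].append(pos)
    # 2. Index requests travel to the owning ranks.
    recv_idx = all_to_all(send_idx)  # recv_idx[owner][requester]
    # 3. Each owner gathers the requested rows from its local slice.
    replies = [[[shards[owner][i - bounds[owner][0]] for i in recv_idx[owner][req]]
                for req in range(world)] for owner in range(world)]
    # 4. Rows travel back along the same route.
    recv_rows = all_to_all(replies)  # recv_rows[requester][owner]
    # 5. Each requester restores the original batch order.
    out = []
    for src, batch in enumerate(batches):
        rows = [None] * len(batch)
        for owner in range(world):
            for pos, row in zip(positions[src][owner], recv_rows[src][owner]):
                rows[pos] = row
        out.append(rows)
    return out


# Usage: 5 entities over 2 simulated ranks, one batch of indices per rank.
table = [[float(i), float(i) * 10.0] for i in range(5)]
bounds = shard_bounds(5, 2)
shards = [table[s:e] for s, e in bounds]
batches = [[4, 0, 3], [1, 1, 2]]
looked_up = sharded_lookup(batches, shards, bounds)
```

In the backward pass the same bucketing would be reused in reverse: gradients for the looked-up rows are sent to the owning rank, which accumulates them into its slice.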

Dense parameters (relation embeddings, Clifford coefficients, …) are wrapped with FSDP and updated by the standard dense optimizer.

Attributes

OptimizedModule

Classes

TorchFSDPTrainer

Single-node multi-GPU trainer: row-wise sharded entity embeddings + FSDP.

Module Contents

dicee.trainer.torch_trainer_fsdp.OptimizedModule = None
class dicee.trainer.torch_trainer_fsdp.TorchFSDPTrainer(args, callbacks)

Bases: dicee.abstracts.AbstractTrainer

Single-node multi-GPU trainer: row-wise sharded entity embeddings + FSDP.

Entity embeddings

Sharded row-wise across GPUs. all_to_all routes each index request to its owning rank; embeddings and gradients travel back the same way. _LocalSparseAdam updates only the rows accessed in each batch, so a step costs O(active_rows) rather than O(shard_size). Each entity row receives gradient contributions from the batches of every rank.
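The O(active_rows) update can be sketched as follows. This is a hypothetical stand-in written in plain Python, not the actual _LocalSparseAdam: the class name LocalSparseAdamSketch, the helper accumulate, the default hyperparameters, and the use of a single global step counter for bias correction are all assumptions made for illustration.

```python
import math
from typing import Dict, List


def accumulate(indices: List[int], grads: List[List[float]]) -> Dict[int, List[float]]:
    """Sum gradients that target the same local row (e.g. sent by several ranks)."""
    summed: Dict[int, List[float]] = {}
    for row, grad in zip(indices, grads):
        if row in summed:
            summed[row] = [a + b for a, b in zip(summed[row], grad)]
        else:
            summed[row] = list(grad)
    return summed


class LocalSparseAdamSketch:
    """Adam that touches only rows with a non-zero gradient (illustrative only)."""

    def __init__(self, shard: List[List[float]], lr: float = 1e-3,
                 betas=(0.9, 0.999), eps: float = 1e-8):
        self.shard = shard  # this rank's slice of entity rows, updated in place
        self.lr, self.eps = lr, eps
        self.b1, self.b2 = betas
        dim = len(shard[0])
        self.m = [[0.0] * dim for _ in shard]  # first-moment estimates
        self.v = [[0.0] * dim for _ in shard]  # second-moment estimates
        self.t = 0  # assumed: one global step counter shared by all rows

    def step(self, row_grads: Dict[int, List[float]]) -> None:
        self.t += 1
        c1 = 1.0 - self.b1 ** self.t  # bias corrections
        c2 = 1.0 - self.b2 ** self.t
        # Loop over active rows only: cost is O(active_rows), not O(shard_size).
        for row, grad in row_grads.items():
            if not any(grad):
                continue  # skip rows whose gradient is entirely zero
            for j, g in enumerate(grad):
                self.m[row][j] = self.b1 * self.m[row][j] + (1.0 - self.b1) * g
                self.v[row][j] = self.b2 * self.v[row][j] + (1.0 - self.b2) * g * g
                m_hat = self.m[row][j] / c1
                v_hat = self.v[row][j] / c2
                self.shard[row][j] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# Usage: only local row 1 receives a gradient, so rows 0 and 2 are untouched.
shard = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
opt = LocalSparseAdamSketch(shard, lr=0.1)
opt.step(accumulate([1, 1], [[0.5, -1.0], [0.5, -1.0]]))
```

Summing duplicate rows before the step reflects the statement that a row can collect gradient from several ranks; whether the real optimizer accumulates in this way is not confirmed by this page.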

Dense parameters (relation embeddings, Clifford coefficients, …)

Wrapped with FSDP; updated by the standard dense optimizer step().

Usage

Launched with torchrun:

torchrun --nproc_per_node=4 -u dicee --trainer torchFSDP …

local_rank
global_rank
is_global_zero
device
model = None
raw_model = None
optimizer = None
loss_func = None
train_dataset_loader = None
ptdtype
ctx
scaler
sharding_strategy
gradient_clip_val
use_compile
num_workers
prefetch_factor
entity_optim_device
fit(*args, **kwargs)
extract_input_outputs(z)