dicee.trainer.torch_trainer_fsdp
Multi-GPU trainer: row-wise sharded entity embeddings + FSDP dense parameters.
Entity embeddings are sharded row-wise across ranks: each rank owns a contiguous slice of entity rows. The forward and backward passes use dist.all_to_all_single to route index requests and returned gradients to the owning rank. A per-rank _LocalSparseAdam updates only the rows that received a non-zero gradient in each batch.
Dense parameters (relation embeddings, Clifford coefficients, …) are wrapped with FSDP and updated by the standard dense optimizer.
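The ownership lookup behind the row-wise routing can be sketched in a single process, without any GPU or process group. This is a minimal illustration, not the trainer's actual code: the helper names shard_bounds and route_requests are hypothetical, and the choice to give the first ranks one extra row when the entity count does not divide evenly is an assumption.

```python
def shard_bounds(num_entities: int, world_size: int) -> list:
    """Split [0, num_entities) into world_size contiguous [start, end) slices.

    Assumption: leftover rows go to the lowest-numbered ranks.
    """
    base, extra = divmod(num_entities, world_size)
    bounds, start = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)
        bounds.append((start, start + size))
        start += size
    return bounds


def route_requests(indices, bounds):
    """Group global entity indices by owning rank, converted to local row offsets."""
    per_rank = [[] for _ in bounds]
    for idx in indices:
        for rank, (start, end) in enumerate(bounds):
            if start <= idx < end:
                per_rank[rank].append(idx - start)
                break
        else:
            raise IndexError(f"entity index {idx} is outside every shard")
    return per_rank


# 10 entities over 4 ranks; one batch asks for five entity rows.
bounds = shard_bounds(10, 4)
requests = route_requests([0, 3, 9, 5, 2], bounds)
send_counts = [len(r) for r in requests]  # rows requested from each rank
```

In a distributed run, the per-rank lists would presumably be flattened into one send buffer and the counts passed as split sizes to dist.all_to_all_single; that step is omitted here because it needs an initialized process group.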
Attributes
Classes
TorchFSDPTrainer: Single-node multi-GPU trainer (row-wise sharded entity embeddings + FSDP).
Module Contents
- class dicee.trainer.torch_trainer_fsdp.TorchFSDPTrainer(args, callbacks)
Bases: dicee.abstracts.AbstractTrainer
Single-node multi-GPU trainer: row-wise sharded entity embeddings + FSDP.
Entity embeddings
Sharded row-wise across GPUs. all_to_all routes each index request to its owning rank; embeddings and gradients travel back the same way. _LocalSparseAdam updates only the rows accessed in each batch, so the cost of an optimizer step is O(active_rows) rather than O(shard_size). Each entity row accumulates gradients from the batches of every rank.
Dense parameters (relation embeddings, Clifford coefficients, …)
Wrapped with FSDP; updated by the standard dense optimizer step().
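The O(active_rows) behaviour of a row-sparse Adam update can be illustrated with a small pure-Python sketch. The source does not show how _LocalSparseAdam is implemented, so the class below (LazyRowAdam, a hypothetical name) is only an assumption about the general technique: moment estimates and weights are touched only for rows that appear in the batch gradient, while a single global step counter drives bias correction.

```python
import math


class LazyRowAdam:
    """Adam that updates only the rows present in the gradient (illustrative)."""

    def __init__(self, num_rows, dim, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
        self.lr, self.betas, self.eps = lr, betas, eps
        self.m = [[0.0] * dim for _ in range(num_rows)]  # first moments
        self.v = [[0.0] * dim for _ in range(num_rows)]  # second moments
        self.step_count = 0

    def step(self, weights, row_grads):
        """row_grads maps a local row index to that row's gradient vector."""
        self.step_count += 1
        b1, b2 = self.betas
        bc1 = 1.0 - b1 ** self.step_count  # bias corrections
        bc2 = 1.0 - b2 ** self.step_count
        # Only rows that received a gradient are visited: O(active_rows).
        for row, grad in row_grads.items():
            m, v, w = self.m[row], self.v[row], weights[row]
            for j, g in enumerate(grad):
                m[j] = b1 * m[j] + (1.0 - b1) * g
                v[j] = b2 * v[j] + (1.0 - b2) * g * g
                w[j] -= self.lr * (m[j] / bc1) / (math.sqrt(v[j] / bc2) + self.eps)


# A shard of 4 rows with 2 dimensions; only row 2 was accessed in this batch.
weights = [[1.0, 1.0] for _ in range(4)]
opt = LazyRowAdam(num_rows=4, dim=2)
opt.step(weights, {2: [0.5, -0.5]})
```

After this step, only row 2 has moved (by roughly the learning rate in each coordinate); the remaining rows and their moment buffers are untouched.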
Usage
- Launched with torchrun:
torchrun --nproc_per_node=4 -u dicee --trainer torchFSDP …
- local_rank
- global_rank
- is_global_zero
- device
- model = None
- raw_model = None
- optimizer = None
- loss_func = None
- train_dataset_loader = None
- ptdtype
- ctx
- scaler
- sharding_strategy
- gradient_clip_val
- use_compile
- num_workers
- prefetch_factor
- entity_optim_device
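torchrun exports LOCAL_RANK, RANK, and WORLD_SIZE to every worker process. The sketch below shows one plausible way the local_rank, global_rank, is_global_zero, and device attributes could be derived from those variables. The function read_rank_info is hypothetical, and the mapping of local rank to a cuda:N device string is an assumption; the source does not show how the trainer populates these fields.

```python
import os


def read_rank_info(env=None):
    """Derive rank-related fields from torchrun-style environment variables."""
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", 0))   # index of this process on its node
    global_rank = int(env.get("RANK", 0))        # index across all processes
    world_size = int(env.get("WORLD_SIZE", 1))   # total number of processes
    return {
        "local_rank": local_rank,
        "global_rank": global_rank,
        "world_size": world_size,
        "is_global_zero": global_rank == 0,
        "device": f"cuda:{local_rank}",          # assumption: one GPU per local rank
    }


# Simulate the environment of the second worker in a 4-process launch.
info = read_rank_info({"LOCAL_RANK": "1", "RANK": "1", "WORLD_SIZE": "4"})
```

Falling back to rank 0 and world size 1 when the variables are absent lets the same code run outside torchrun, for example in a single-process debugging session.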