dicee.trainer.torch_trainer_fsdp
================================

.. py:module:: dicee.trainer.torch_trainer_fsdp

.. autoapi-nested-parse::

   Multi-GPU trainer: row-wise sharded entity embeddings + FSDP dense parameters.

   Entity embeddings are sharded row-wise across ranks via ``dist.all_to_all_single``.
   Each rank owns a contiguous slice of entity rows; forward and backward use
   all_to_all to route index requests and gradient returns to the owning rank.
   A per-rank ``_LocalSparseAdam`` updates only the rows that received a non-zero
   gradient each batch.

   Dense parameters (relation embeddings, Clifford coefficients, …) are wrapped
   with FSDP and updated by the standard dense optimizer.


Attributes
----------

.. autoapisummary::

   dicee.trainer.torch_trainer_fsdp.OptimizedModule


Classes
-------

.. autoapisummary::

   dicee.trainer.torch_trainer_fsdp.TorchFSDPTrainer


Module Contents
---------------

.. py:data:: OptimizedModule
   :value: None


.. py:class:: TorchFSDPTrainer(args, callbacks)

   Bases: :py:obj:`dicee.abstracts.AbstractTrainer`

   Single-node multi-GPU trainer: row-wise sharded entity embeddings + FSDP.

   .. rubric:: Entity embeddings

   Sharded row-wise across GPUs. all_to_all routes each index request to its
   owning rank; embeddings and gradients travel back the same way.
   ``_LocalSparseAdam`` updates only the rows accessed in each batch:
   O(active_rows), not O(shard_size). Each entity receives gradient from every
   rank's batches.

   .. rubric:: Dense parameters (relation embeddings, Clifford coefficients, …)

   Wrapped with FSDP; updated by the standard dense optimizer ``step()``.

   .. rubric:: Usage

   Launched with torchrun::

      torchrun --nproc_per_node=4 -u dicee --trainer torchFSDP ...


   .. py:attribute:: local_rank


   .. py:attribute:: global_rank


   .. py:attribute:: is_global_zero


   .. py:attribute:: device


   .. py:attribute:: model
      :value: None


   .. py:attribute:: raw_model
      :value: None


   .. py:attribute:: optimizer
      :value: None


   .. py:attribute:: loss_func
      :value: None


   .. py:attribute:: train_dataset_loader
      :value: None


   .. py:attribute:: ptdtype


   .. py:attribute:: ctx


   .. py:attribute:: scaler


   .. py:attribute:: sharding_strategy


   .. py:attribute:: gradient_clip_val


   .. py:attribute:: use_compile


   .. py:attribute:: num_workers


   .. py:attribute:: prefetch_factor


   .. py:attribute:: entity_optim_device


   .. py:method:: fit(*args, **kwargs)


   .. py:method:: extract_input_outputs(z)
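The row-wise sharding described for ``TorchFSDPTrainer`` can be illustrated with a single-process sketch. It does not call ``torch.distributed``; the exchange that ``dist.all_to_all_single`` would perform is simulated with plain lists. The helper names (``shard_bounds``, ``owner_of``, ``route_requests``, ``simulated_lookup``) and the near-equal split policy are assumptions made for illustration and are not taken from the dicee source.

```python
# Single-process sketch of row-wise shard routing (no torch.distributed needed).
# All function names and the split policy are hypothetical.

def shard_bounds(num_rows, world_size):
    """Return one contiguous (start, end) row slice per rank."""
    base, extra = divmod(num_rows, world_size)
    bounds, start = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)  # first `extra` ranks get one more row
        bounds.append((start, start + size))
        start += size
    return bounds

def owner_of(row, bounds):
    """Return the rank whose slice contains the global row id."""
    for rank, (start, end) in enumerate(bounds):
        if start <= row < end:
            return rank
    raise IndexError(row)

def route_requests(indices, bounds):
    """Group requested rows by owning rank as (batch_position, local_row) pairs."""
    buckets = [[] for _ in bounds]
    for position, row in enumerate(indices):
        rank = owner_of(row, bounds)
        buckets[rank].append((position, row - bounds[rank][0]))
    return buckets

def simulated_lookup(indices, shards, bounds):
    """Stand-in for the all_to_all round trip: each owner answers its bucket."""
    out = [None] * len(indices)
    for rank, bucket in enumerate(route_requests(indices, bounds)):
        for position, local_row in bucket:
            out[position] = shards[rank][local_row]
    return out

bounds = shard_bounds(num_rows=10, world_size=4)
# Toy embedding table: global row i stores the vector [float(i)].
shards = [[[float(r)] for r in range(s, e)] for s, e in bounds]
result = simulated_lookup([7, 0, 3, 9], shards, bounds)
```

In a real multi-GPU run, each bucket would be sent to its owning rank and the looked-up rows sent back; gradients would follow the reverse path in the backward pass.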
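The O(active_rows) claim for the entity optimizer can be sketched as an Adam-style update that visits only the rows carrying a gradient. This is a pure-Python illustration, not the ``_LocalSparseAdam`` implementation, whose details are not shown here; the class name ``ToySparseAdam``, its hyperparameter defaults, and the dictionary gradient format are hypothetical.

```python
import math

class ToySparseAdam:
    """Hypothetical sketch: Adam moments are updated only for rows with gradients."""

    def __init__(self, rows, lr=0.1, betas=(0.9, 0.999), eps=1e-8):
        self.rows = rows                      # local shard: list of float lists
        self.lr, self.betas, self.eps = lr, betas, eps
        self.m = [[0.0] * len(r) for r in rows]  # first-moment estimates
        self.v = [[0.0] * len(r) for r in rows]  # second-moment estimates
        self.step_count = 0
        self.rows_touched = 0

    def step(self, sparse_grad):
        """sparse_grad maps local row id -> gradient vector for that row."""
        self.step_count += 1
        b1, b2 = self.betas
        c1 = 1 - b1 ** self.step_count        # bias corrections
        c2 = 1 - b2 ** self.step_count
        for row, grad in sparse_grad.items():  # loop cost scales with active rows only
            for j, g in enumerate(grad):
                self.m[row][j] = b1 * self.m[row][j] + (1 - b1) * g
                self.v[row][j] = b2 * self.v[row][j] + (1 - b2) * g * g
                m_hat = self.m[row][j] / c1
                v_hat = self.v[row][j] / c2
                self.rows[row][j] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)
            self.rows_touched += 1

shard = [[1.0, 1.0] for _ in range(5)]
opt = ToySparseAdam(shard)
opt.step({1: [0.5, -0.5], 3: [1.0, 1.0]})    # rows 0, 2 and 4 are never visited
```

Rows without a gradient keep both their values and their moment estimates unchanged, which is the behaviour the O(active_rows) description implies.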