Source code for dicee.models.clifford

from .base_model import BaseKGE
import torch
[docs] class Keci(BaseKGE): def __init__(self, args): super().__init__(args) self.name = 'Keci' self.p = self.args.get("p", 0) self.q = self.args.get("q", 0) if self.p is None: self.p = 0 if self.q is None: self.q = 0 self.r = self.embedding_dim / (self.p + self.q + 1) try: assert self.r.is_integer() except AssertionError: raise AssertionError(f'r = embedding_dim / (p + q+ 1) must be a whole number\n' f'Currently {self.r}={self.embedding_dim} / ({self.p}+ {self.q} +1)') self.r = int(self.r) self.requires_grad_for_interactions = True # Initialize parameters for dimension scaling # TODO:Do we need coefficients for the real part ? if self.p > 0: self.p_coefficients = torch.nn.Embedding(num_embeddings=1, embedding_dim=self.p) torch.nn.init.zeros_(self.p_coefficients.weight) if self.q > 0: self.q_coefficients = torch.nn.Embedding(num_embeddings=1, embedding_dim=self.q) torch.nn.init.zeros_(self.q_coefficients.weight)
[docs] def compute_sigma_pp(self, hp, rp): """ Compute sigma_{pp} = \sum_{i=1}^{p-1} \sum_{k=i+1}^p (h_i r_k - h_k r_i) e_i e_k sigma_{pp} captures the interactions between along p bases For instance, let p e_1, e_2, e_3, we compute interactions between e_1 e_2, e_1 e_3 , and e_2 e_3 This can be implemented with a nested two for loops results = [] for i in range(p - 1): for k in range(i + 1, p): results.append(hp[:, :, i] * rp[:, :, k] - hp[:, :, k] * rp[:, :, i]) sigma_pp = torch.stack(results, dim=2) assert sigma_pp.shape == (b, r, int((p * (p - 1)) / 2)) Yet, this computation would be quite inefficient. Instead, we compute interactions along all p, e.g., e1e1, e1e2, e1e3, e2e1, e2e2, e2e3, e3e1, e3e2, e3e3 Then select the triangular matrix without diagonals: e1e2, e1e3, e2e3. """ # Compute indexes for the upper triangle of p by p matrix indices = torch.triu_indices(self.p, self.p, offset=1) # Compute p by p operations sigma_pp = torch.einsum('nrp,nrx->nrpx', hp, rp) - torch.einsum('nrx,nrp->nrpx', hp, rp) sigma_pp = sigma_pp[:, :, indices[0], indices[1]] return sigma_pp
[docs] def compute_sigma_qq(self, hq, rq): """ Compute sigma_{qq} = \sum_{j=1}^{p+q-1} \sum_{k=j+1}^{p+q} (h_j r_k - h_k r_j) e_j e_k sigma_{q} captures the interactions between along q bases For instance, let q e_1, e_2, e_3, we compute interactions between e_1 e_2, e_1 e_3 , and e_2 e_3 This can be implemented with a nested two for loops results = [] for j in range(q - 1): for k in range(j + 1, q): results.append(hq[:, :, j] * rq[:, :, k] - hq[:, :, k] * rq[:, :, j]) sigma_qq = torch.stack(results, dim=2) assert sigma_qq.shape == (b, r, int((q * (q - 1)) / 2)) Yet, this computation would be quite inefficient. Instead, we compute interactions along all p, e.g., e1e1, e1e2, e1e3, e2e1, e2e2, e2e3, e3e1, e3e2, e3e3 Then select the triangular matrix without diagonals: e1e2, e1e3, e2e3. """ # Compute indexes for the upper triangle of p by p matrix if self.q > 1: indices = torch.triu_indices(self.q, self.q, offset=1) # Compute p by p operations sigma_qq = torch.einsum('nrp,nrx->nrpx', hq, rq) - torch.einsum('nrx,nrp->nrpx', hq, rq) sigma_qq = sigma_qq[:, :, indices[0], indices[1]] else: sigma_qq = torch.zeros((len(hq), self.r, int((self.q * (self.q - 1)) / 2))) return sigma_qq
[docs] def compute_sigma_pq(self, *, hp, hq, rp, rq): """ \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j results = [] sigma_pq = torch.zeros(b, r, p, q) for i in range(p): for j in range(q): sigma_pq[:, :, i, j] = hp[:, :, i] * rq[:, :, j] - hq[:, :, j] * rp[:, :, i] print(sigma_pq.shape) """ sigma_pq = torch.einsum('nrp,nrq->nrpq', hp, rq) - torch.einsum('nrq,nrp->nrpq', hq, rp) assert sigma_pq.shape[1:] == (self.r, self.p, self.q) return sigma_pq
[docs] def apply_coefficients(self, hp, hq, rp, rq): """ Multiplying a base vector with its scalar coefficient """ if self.p > 0: hp = hp * self.p_coefficients.weight rp = rp * self.p_coefficients.weight if self.q > 0: hq = hq * self.q_coefficients.weight rq = rq * self.q_coefficients.weight return hp, hq, rp, rq
[docs] def clifford_multiplication(self, h0, hp, hq, r0, rp, rq): """ Compute our CL multiplication h = h_0 + \sum_{i=1}^p h_i e_i + \sum_{j=p+1}^{p+q} h_j e_j r = r_0 + \sum_{i=1}^p r_i e_i + \sum_{j=p+1}^{p+q} r_j e_j ei ^2 = +1 for i =< i =< p ej ^2 = -1 for p < j =< p+q ei ej = -eje1 for i \neq j h r = sigma_0 + sigma_p + sigma_q + sigma_{pp} + sigma_{q}+ sigma_{pq} where (1) sigma_0 = h_0 r_0 + \sum_{i=1}^p (h_0 r_i) e_i - \sum_{j=p+1}^{p+q} (h_j r_j) e_j (2) sigma_p = \sum_{i=1}^p (h_0 r_i + h_i r_0) e_i (3) sigma_q = \sum_{j=p+1}^{p+q} (h_0 r_j + h_j r_0) e_j (4) sigma_{pp} = \sum_{i=1}^{p-1} \sum_{k=i+1}^p (h_i r_k - h_k r_i) e_i e_k (5) sigma_{qq} = \sum_{j=1}^{p+q-1} \sum_{k=j+1}^{p+q} (h_j r_k - h_k r_j) e_j e_k (6) sigma_{pq} = \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j """ n = len(h0) assert h0.shape == (n, self.r) == r0.shape == (n, self.r) assert hp.shape == (n, self.r, self.p) == rp.shape == (n, self.r, self.p) assert hq.shape == (n, self.r, self.q) == rq.shape == (n, self.r, self.q) # (1) sigma_0 = h0 * r0 + torch.sum(hp * rp, dim=2) - torch.sum(hq * rq, dim=2) assert sigma_0.shape == (n, self.r) # (2) sigma_p = torch.einsum('nr,nrp->nrp', h0, rp) + torch.einsum('nr,nrp->nrp', r0, hp) assert sigma_p.shape == (n, self.r, self.p) # (3) sigma_q = torch.einsum('nr,nrq->nrq', h0, rq) + torch.einsum('nr,nrq->nrq', r0, hq) # (4) sigma_pp = self.compute_sigma_pp(hp, rp) # (5) sigma_qq = self.compute_sigma_qq(hq, rq) # (6) sigma_pq = torch.einsum('bkp,bkq->bkpq', hp, rq) - torch.einsum('bkp,bkq->bkpq', rp, hq) assert sigma_pq.shape == (n, self.r, self.p, self.q) return sigma_0, sigma_p, sigma_q, sigma_pp, sigma_qq, sigma_pq
[docs] def construct_cl_multivector(self, x: torch.FloatTensor, r: int, p: int, q: int) -> tuple[ torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """ Construct a batch of multivectors Cl_{p,q}(\mathbb{R}^d) Parameter --------- x: torch.FloatTensor with (n,d) shape Returns ------- a0: torch.FloatTensor with (n,r) shape ap: torch.FloatTensor with (n,r,p) shape aq: torch.FloatTensor with (n,r,q) shape """ batch_size, d = x.shape # (1) A_{n \times k}: take the first k columns a0 = x[:, :r].view(batch_size, r) # (2) B_{n \times p}, C_{n \times q}: take the self.k * self.p columns after the k. column if p > 0: ap = x[:, r: r + (r * p)].view(batch_size, r, p) else: ap = torch.zeros((batch_size, r, p), device=self.device) if q > 0: # (3) B_{n \times p}, C_{n \times q}: take the last self.r * self.q . aq = x[:, -(r * q):].view(batch_size, r, q) else: aq = torch.zeros((batch_size, r, q), device=self.device) return a0, ap, aq
[docs] def forward_k_vs_with_explicit(self, x: torch.Tensor): n = len(x) # (1) Retrieve real-valued embedding vectors. head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x) # (2) Construct multi-vector in Cl_{p,q} (\mathbb{R}^d) for head entities and relations h0, hp, hq = self.construct_cl_multivector(head_ent_emb, r=self.r, p=self.p, q=self.q) r0, rp, rq = self.construct_cl_multivector(rel_ent_emb, r=self.r, p=self.p, q=self.q) E = self.entity_embeddings.weight # Clifford mul. sigma_0 = h0 * r0 + torch.sum(hp * rp, dim=2) - torch.sum(hq * rq, dim=2) sigma_p = torch.einsum('nr,nrp->nrp', h0, rp) + torch.einsum('nrp, nr->nrp', hp, r0) sigma_q = torch.einsum('nr,nrq->nrq', h0, rq) + torch.einsum('nrq, nr->nrq', hq, r0) t0 = E[:, :self.r] score_sigma_0 = sigma_0 @ t0.transpose(1, 0) if self.p > 0: tp = E[:, self.r: self.r + (self.r * self.p)].view(self.num_entities, self.r, self.p) score_sigma_p = torch.einsum('bkp,ekp->be', sigma_p, tp) else: score_sigma_p = 0 if self.q > 0: tq = E[:, -(self.r * self.q):].view(self.num_entities, self.r, self.q) score_sigma_q = torch.einsum('bkp,ekp->be', sigma_q, tq) else: score_sigma_q = 0 # Compute sigma_pp sigma_qq and sigma_pq if self.p > 1: results = [] for i in range(self.p - 1): for k in range(i + 1, self.p): results.append(hp[:, :, i] * rp[:, :, k] - hp[:, :, k] * rp[:, :, i]) sigma_pp = torch.stack(results, dim=2) assert sigma_pp.shape == (n, self.r, int((self.p * (self.p - 1)) / 2)) sigma_pp = torch.sum(sigma_pp, dim=[1, 2]).view(n, 1) del results else: sigma_pp = 0 if self.q > 1: results = [] for j in range(self.q - 1): for k in range(j + 1, self.q): results.append(hq[:, :, j] * rq[:, :, k] - hq[:, :, k] * rq[:, :, j]) sigma_qq = torch.stack(results, dim=2) del results assert sigma_qq.shape == (n, self.r, int((self.q * (self.q - 1)) / 2)) sigma_qq = torch.sum(sigma_qq, dim=[1, 2]).view(n, 1) else: sigma_qq = 0 if self.p >= 1 and self.q >= 1: sigma_pq = torch.zeros(n, self.r, self.p, self.q) for i in range(self.p): for j in range(self.q): sigma_pq[:, :, i, j] = hp[:, :, i] * rq[:, :, j] - hq[:, :, j] * rp[:, :, i] sigma_pq = torch.sum(sigma_pq, dim=[1, 2, 3]).view(n, 1) else: sigma_pq = 0 return score_sigma_0 + score_sigma_p + score_sigma_q + sigma_pp + sigma_qq + sigma_pq
[docs] def k_vs_all_score(self, bpe_head_ent_emb, bpe_rel_ent_emb, E): # (2) Construct multi-vector in Cl_{p,q} (\mathbb{R}^d) for head entities and relations h0, hp, hq = self.construct_cl_multivector(bpe_head_ent_emb, r=self.r, p=self.p, q=self.q) r0, rp, rq = self.construct_cl_multivector(bpe_rel_ent_emb, r=self.r, p=self.p, q=self.q) hp, hq, rp, rq = self.apply_coefficients(hp, hq, rp, rq) # (3.1) Extract real part t0 = E[:, :self.r] num_entities = len(E) # (4) Compute a triple score based on interactions described by the basis 1. Eq. 20 h0r0t0 = torch.einsum('br,er->be', h0 * r0, t0) # (5) Compute a triple score based on interactions described by the bases of p {e_1, ..., e_p}. Eq. 21 if self.p > 0: tp = E[:, self.r: self.r + (self.r * self.p)].view(num_entities, self.r, self.p) hp_rp_t0 = torch.einsum('brp, er -> be', hp * rp, t0) h0_rp_tp = torch.einsum('brp, erp -> be', torch.einsum('br, brp -> brp', h0, rp), tp) hp_r0_tp = torch.einsum('brp, erp -> be', torch.einsum('brp, br -> brp', hp, r0), tp) score_p = hp_rp_t0 + h0_rp_tp + hp_r0_tp else: score_p = 0 # (5) Compute a triple score based on interactions described by the bases of q {e_{p+1}, ..., e_{p+q}}. Eq. 22 if self.q > 0: tq = E[:, -(self.r * self.q):].view(num_entities, self.r, self.q) h0_rq_tq = torch.einsum('brq, erq -> be', torch.einsum('br, brq -> brq', h0, rq), tq) hq_r0_tq = torch.einsum('brq, erq -> be', torch.einsum('brq, br -> brq', hq, r0), tq) hq_rq_t0 = torch.einsum('brq, er -> be', hq * rq, t0) score_q = h0_rq_tq + hq_r0_tq - hq_rq_t0 else: score_q = 0 if self.p >= 2: sigma_pp = torch.sum(self.compute_sigma_pp(hp, rp), dim=[1, 2]).unsqueeze(-1) else: sigma_pp = 0 if self.q >= 2: sigma_qq = torch.sum(self.compute_sigma_qq(hq, rq), dim=[1, 2]).unsqueeze(-1) else: sigma_qq = 0 if self.p >= 2 and self.q >= 2: sigma_pq = torch.sum(self.compute_sigma_pq(hp=hp, hq=hq, rp=rp, rq=rq), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_pq = 0 return h0r0t0 + score_p + score_q + sigma_pp + sigma_qq + sigma_pq
[docs] def forward_k_vs_all(self, x: torch.Tensor) -> torch.FloatTensor: """ Kvsall training (1) Retrieve real-valued embedding vectors for heads and relations \mathbb{R}^d . (2) Construct head entity and relation embeddings according to Cl_{p,q}(\mathbb{R}^d) . (3) Perform Cl multiplication (4) Inner product of (3) and all entity embeddings forward_k_vs_with_explicit and this funcitons are identical Parameter --------- x: torch.LongTensor with (n,2) shape Returns ------- torch.FloatTensor with (n, |E|) shape """ # (1) Retrieve real-valued embedding vectors. head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x) # (3) Extract all entity embeddings E = self.entity_embeddings.weight return self.k_vs_all_score(head_ent_emb, rel_ent_emb, E)
[docs] def construct_batch_selected_cl_multivector(self, x: torch.FloatTensor, r: int, p: int, q: int) -> tuple[ torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """ Construct a batch of batchs multivectors Cl_{p,q}(\mathbb{R}^d) Parameter --------- x: torch.FloatTensor with (n,k, d) shape Returns ------- a0: torch.FloatTensor with (n,k, m) shape ap: torch.FloatTensor with (n,k, m, p) shape aq: torch.FloatTensor with (n,k, m, q) shape """ batch_size, k, d = x.shape # (1) Take the first m columns for each k a0 = x[:, :, :r].view(batch_size, k, r) # (2) B_{n \times p}, C_{n \times q}: take the self.k * self.p columns after the k. column if p > 0: ap = x[:, :, r: r + (r * p)].view(batch_size, k, r, p) else: ap = torch.zeros((batch_size, k, r, p), device=self.device) if q > 0: # (3) B_{n \times p}, C_{n \times q}: take the last self.r * self.q . aq = x[:, :, -(r * q):].view(batch_size, k, r, q) else: aq = torch.zeros((batch_size, k, r, q), device=self.device) return a0, ap, aq
[docs] def forward_k_vs_sample(self, x: torch.LongTensor, target_entity_idx: torch.LongTensor) -> torch.FloatTensor: """ Parameter --------- x: torch.LongTensor with (n,2) shape target_entity_idx: torch.LongTensor with (n, k ) shape k denotes the selected number of examples. Returns ------- torch.FloatTensor with (n, k) shape """ # (1) Retrieve real-valued embedding vectors. # (b, d), (b, d) head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x) # (2) Construct multi-vector embeddings in Cl_{p,q} (\mathbb{R}^d) for head entities and relations # (b, m), (b, m, p), (b, m, q) h0, hp, hq = self.construct_cl_multivector(head_ent_emb, r=self.r, p=self.p, q=self.q) # (b, m), (b, m, p), (b, m, q) r0, rp, rq = self.construct_cl_multivector(rel_ent_emb, r=self.r, p=self.p, q=self.q) hp, hq, rp, rq = self.apply_coefficients(hp, hq, rp, rq) # (3) (b, k, d) Retrieve real-valued embedding vectors of selected entities. E = self.entity_embeddings(target_entity_idx) # (4) Construct multi-vector embeddings in Cl_{p,q} (\mathbb{R}^d) for head entities and relations # (b, k, m), (b, k, m, p), (b, k, m, q) t0, tp, tq = self.construct_batch_selected_cl_multivector(E, r=self.r, p=self.p, q=self.q) # (4) Batch vector matrix multiplications # Equivalent computations # h0*r0@t0.transpose(1,2) # torch.einsum('bm, bmk -> bk', h0 * r0, t0.transpose(1, 2)) # torch.einsum('bm, bkm -> bk', h0 * r0, t0) h0r0t0 = torch.einsum('bm, bkm -> bk', h0 * r0, t0) # (5) Compute a triple score based on interactions described by the bases of p {e_1, ..., e_p}. Eq. 21 if self.p > 0: raise NotImplementedError("Sample with p>0 for Keci not implemented") """ # Second term in Eq.16 hp_rp_t0 = torch.einsum('brp, br -> b', hp * rp, t0) # Eq. 17 # b=e h0_rp_tp = torch.einsum('brp, erp -> b', torch.einsum('br, brp -> brp', h0, rp), tp) hp_r0_tp = torch.einsum('brp, erp -> b', torch.einsum('brp, br -> brp', hp, r0), tp) score_p = hp_rp_t0 + h0_rp_tp + hp_r0_tp """ else: score_p = 0 # (6) Compute a triple score based on interactions described by the bases of q {e_{p+1}, ..., e_{p+q}}. Eq. 22 if self.q > 0: # \sum_{j=p+1}^{p+q} (h_j r_j t_0) : Third parth of the in Eq 16. # Equivalent computation # torch.einsum('bmq, bkm -> bk', hq*rq, t0) => (hq * rq).transpose(1,2) @ t0.transpose(1,2) hq_rq_t0 = torch.einsum('bmq, bkm -> bk', hq * rq, t0) # Eq. 18. Batch elementwise matrix matrix multiplication: bmq -> bkmq rq_tq=torch.einsum('bmq, bkmq -> bkmq', rq, tq) h0_rq_tq = torch.einsum('bm, bkmq -> bk', h0, rq_tq) hq_tq=torch.einsum('bmq, bkmq -> bkmq',hq, tq) r0_hq_tq = torch.einsum('bm, bkmq -> bk', r0, hq_tq) score_q = - hq_rq_t0 + (h0_rq_tq + r0_hq_tq) else: score_q = 0 if self.p >= 2: sigma_pp = torch.sum(self.compute_sigma_pp(hp, rp), dim=[1, 2]).unsqueeze(-1) else: sigma_pp = 0 if self.q >= 2: sigma_qq = torch.sum(self.compute_sigma_qq(hq, rq), dim=[1, 2]).unsqueeze(-1) else: sigma_qq = 0 if self.p >= 2 and self.q >= 2: sigma_pq = torch.sum(self.compute_sigma_pq(hp=hp, hq=hq, rp=rp, rq=rq), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_pq = 0 return h0r0t0 + score_p + score_q + sigma_pp + sigma_qq + sigma_pq
[docs] def score(self, h, r, t): # (2) Construct multi-vector in Cl_{p,q} (\mathbb{R}^d) for head entities and relations h0, hp, hq = self.construct_cl_multivector(h, r=self.r, p=self.p, q=self.q) r0, rp, rq = self.construct_cl_multivector(r, r=self.r, p=self.p, q=self.q) t0, tp, tq = self.construct_cl_multivector(t, r=self.r, p=self.p, q=self.q) if self.q > 0: self.q_coefficients = self.q_coefficients.to(h0.device, non_blocking=True) hp, hq, rp, rq = self.apply_coefficients(hp, hq, rp, rq) # (4) Compute a triple score based on interactions described by the basis 1. Eq. 20 h0r0t0 = torch.einsum('br, br -> b', h0 * r0, t0) # (5) Compute a triple score based on interactions described by the bases of p {e_1, ..., e_p}. Eq. 21 if self.p > 0: # Second term in Eq.16 hp_rp_t0 = torch.einsum('brp, br -> b', hp * rp, t0) # Eq. 17 # b=e h0_rp_tp = torch.einsum('brp, erp -> b', torch.einsum('br, brp -> brp', h0, rp), tp) hp_r0_tp = torch.einsum('brp, erp -> b', torch.einsum('brp, br -> brp', hp, r0), tp) score_p = hp_rp_t0 + h0_rp_tp + hp_r0_tp else: score_p = 0 # (5) Compute a triple score based on interactions described by the bases of q {e_{p+1}, ..., e_{p+q}}. Eq. 22 if self.q > 0: # Third item in Eq 16. hq_rq_t0 = torch.einsum('brq, br -> b', hq * rq, t0) # Eq. 18. h0_rq_tq = torch.einsum('br, brq -> b', h0, rq * tq) r0_hq_tq = torch.einsum('br, brq -> b', r0, hq * tq) score_q = - hq_rq_t0 + (h0_rq_tq + r0_hq_tq) else: score_q = 0 if self.p >= 2: sigma_pp = torch.sum(self.compute_sigma_pp(hp, rp), dim=[1, 2]).unsqueeze(-1) else: sigma_pp = 0 if self.q >= 2: sigma_qq = torch.sum(self.compute_sigma_qq(hq, rq), dim=[1, 2]).unsqueeze(-1) else: sigma_qq = 0 if self.p >= 2 and self.q >= 2: sigma_pq = torch.sum(self.compute_sigma_pq(hp=hp, hq=hq, rp=rp, rq=rq), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_pq = 0 return h0r0t0 + score_p + score_q + sigma_pp + sigma_qq + sigma_pq
[docs] def forward_triples(self, x: torch.Tensor) -> torch.FloatTensor: """ Parameter --------- x: torch.LongTensor with (n,3) shape Returns ------- torch.FloatTensor with (n) shape """ # (1) Retrieve real-valued embedding vectors. head_ent_emb, rel_ent_emb, tail_ent_emb = self.get_triple_representation(x) # (2) Construct multi-vector in Cl_{p,q} (\mathbb{R}^d) for head entities and relations h0, hp, hq = self.construct_cl_multivector(head_ent_emb, r=self.r, p=self.p, q=self.q) r0, rp, rq = self.construct_cl_multivector(rel_ent_emb, r=self.r, p=self.p, q=self.q) t0, tp, tq = self.construct_cl_multivector(tail_ent_emb, r=self.r, p=self.p, q=self.q) hp, hq, rp, rq = self.apply_coefficients( hp, hq, rp, rq) # (4) Compute a triple score based on interactions described by the basis 1. Eq. 20 h0r0t0 = torch.einsum('br, br -> b', h0 * r0, t0) # (5) Compute a triple score based on interactions described by the bases of p {e_1, ..., e_p}. Eq. 21 if self.p > 0: # Second term in Eq.16 hp_rp_t0 = torch.einsum('brp, br -> b', hp * rp, t0) # Eq. 17 # b=e h0_rp_tp = torch.einsum('brp, erp -> b', torch.einsum('br, brp -> brp', h0, rp), tp) hp_r0_tp = torch.einsum('brp, erp -> b', torch.einsum('brp, br -> brp', hp, r0), tp) score_p = hp_rp_t0 + h0_rp_tp + hp_r0_tp else: score_p = 0 # (5) Compute a triple score based on interactions described by the bases of q {e_{p+1}, ..., e_{p+q}}. Eq. 22 if self.q > 0: # Third item in Eq 16. hq_rq_t0 = torch.einsum('brq, br -> b', hq * rq, t0) # Eq. 18. h0_rq_tq = torch.einsum('br, brq -> b', h0, rq * tq) r0_hq_tq = torch.einsum('br, brq -> b', r0, hq * tq) score_q = - hq_rq_t0 + (h0_rq_tq + r0_hq_tq) else: score_q = 0 if self.p >= 2: sigma_pp = torch.sum(self.compute_sigma_pp(hp, rp), dim=[1, 2]).unsqueeze(-1) else: sigma_pp = 0 if self.q >= 2: sigma_qq = torch.sum(self.compute_sigma_qq(hq, rq), dim=[1, 2]).unsqueeze(-1) else: sigma_qq = 0 if self.p >= 2 and self.q >= 2: sigma_pq = torch.sum(self.compute_sigma_pq(hp=hp, hq=hq, rp=rp, rq=rq), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_pq = 0 return h0r0t0 + score_p + score_q + sigma_pp + sigma_qq + sigma_pq
[docs] class KeciBase(Keci): " Without learning dimension scaling" def __init__(self, args): super().__init__(args) self.name = 'KeciBase' self.requires_grad_for_interactions = False print(f'r:{self.r}\t p:{self.p}\t q:{self.q}') if self.p > 0: self.p_coefficients = torch.nn.Embedding(num_embeddings=1, embedding_dim=self.p) torch.nn.init.ones_(self.p_coefficients.weight) if self.q > 0: self.q_coefficients = torch.nn.Embedding(num_embeddings=1, embedding_dim=self.q) torch.nn.init.ones_(self.q_coefficients.weight)
[docs] class DeCaL(BaseKGE): def __init__(self, args): super().__init__(args) self.name = 'DeCaL' self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim) self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim) self.p = self.args.get("p", 0) self.q = self.args.get("q", 0) self.r = self.args.get("r", 0) self.re = int(self.embedding_dim / (self.r + self.p + self.q + 1)) # Initialize parameters for dimension scaling if self.p > 0: self.p_coefficients = torch.nn.Embedding(num_embeddings=1, embedding_dim=self.p) torch.nn.init.zeros_(self.p_coefficients.weight) if self.q > 0: self.q_coefficients = torch.nn.Embedding(num_embeddings=1, embedding_dim=self.q) torch.nn.init.zeros_(self.q_coefficients.weight) if self.r > 0: self.r_coefficients = torch.nn.Embedding(num_embeddings=1, embedding_dim=self.r) torch.nn.init.zeros_(self.r_coefficients.weight)
[docs] def forward_triples(self, x: torch.Tensor) -> torch.FloatTensor: """ Parameter --------- x: torch.LongTensor with (n, ) shape Returns ------- torch.FloatTensor with (n) shape """ # (1) Retrieve real-valued embedding vectors. head_ent_emb, rel_ent_emb, tail_ent_emb = self.get_triple_representation(x) # (2) Construct multi-vector in Cl_{p,q,r} (\mathbb{R}^d) for head entities and relations h0, hp, hq, hk = self.construct_cl_multivector(head_ent_emb, re=self.re, p=self.p, q=self.q, r=self.r) r0, rp, rq, rk = self.construct_cl_multivector(rel_ent_emb, re=self.re, p=self.p, q=self.q, r=self.r) t0, tp, tq, tk = self.construct_cl_multivector(tail_ent_emb, re=self.re, p=self.p, q=self.q, r=self.r) # h0, hp, hq, hk, h0, rp, rq, rk = self.apply_coefficients(h0, hp, hq, hk, h0, rp, rq,rk) # (4) Compute a triple score based on interactions described by the basis 1. h0r0t0 = torch.einsum('br, br -> b', h0 * r0, t0) # (5) Compute a triple score based on interactions described by the bases of p {e_1, ..., e_p}. if self.p > 0: # Second term in Eq.16 hp_rp_t0 = torch.einsum('brp, br -> b', hp * rp, t0) # Eq. 17 # b=e h0_rp_tp = torch.einsum('brp, erp -> b', torch.einsum('br, brp -> brp', h0, rp), tp) hp_r0_tp = torch.einsum('brp, erp -> b', torch.einsum('brp, br -> brp', hp, r0), tp) score_p = hp_rp_t0 + h0_rp_tp + hp_r0_tp else: score_p = 0 # (5) Compute a triple score based on interactions described by the bases of q {e_{p+1}, ..., e_{p+q}}. Eq. 22 if self.q > 0: # Third item in Eq 16. hq_rq_t0 = torch.einsum('brq, br -> b', hq * rq, t0) # Eq. 18. h0_rq_tq = torch.einsum('br, brq -> b', h0, rq * tq) r0_hq_tq = torch.einsum('br, brq -> b', r0, hq * tq) score_q = - hq_rq_t0 + (h0_rq_tq + r0_hq_tq) else: score_q = 0 if self.r > 0: # Eq. 18. h0_rk_tk = torch.einsum('br, brk -> b', h0, rk * tk) r0_hk_tk = torch.einsum('br, brk -> b', r0, hk * tk) score_r = (h0_rk_tk + r0_hk_tk) else: score_r = 0 if self.p >= 2: sigma_pp = torch.sum(self.compute_sigma_pp(hp, rp), dim=[1, 2]).unsqueeze(-1) else: sigma_pp = 0 if self.q >= 2: sigma_qq = torch.sum(self.compute_sigma_qq(hq, rq), dim=[1, 2]).unsqueeze(-1) else: sigma_qq = 0 if self.r >= 2: sigma_rr = torch.sum(self.compute_sigma_qq(hk, rk), dim=[1, 2]).unsqueeze(-1) else: sigma_rr = 0 if self.p >= 2 and self.q >= 2: sigma_pq = torch.sum(self.compute_sigma_pq(hp=hp, hq=hq, rp=rp, rq=rq), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_pq = 0 if self.p >= 2 and self.r >= 2: sigma_pr = torch.sum(self.compute_sigma_pq(hp=hp, hk=hk, rp=rp, rk=rk), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_pr = 0 if self.q >= 2 and self.r >= 2: sigma_qr = torch.sum(self.compute_sigma_pq(hq=hq, hk=hk, rq=rq, rk=rk), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_qr = 0 return h0r0t0 + score_p + score_q + score_r + sigma_pp + sigma_qq + sigma_rr + sigma_pq + sigma_qr + sigma_pr
[docs] def cl_pqr(self, a:torch.tensor)->torch.tensor: ''' Input: tensor(batch_size, emb_dim) ---> output: tensor with 1+p+q+r components with size (batch_size, emb_dim/(1+p+q+r)) each. 1) takes a tensor of size (batch_size, emb_dim), split it into 1 + p + q +r components, hence 1+p+q+r must be a divisor of the emb_dim. 2) Return a list of the 1+p+q+r components vectors, each are tensors of size (batch_size, emb_dim/(1+p+q+r)) ''' # num1 = 2**(p+q+r) #total number of vector in cl_pqr then after choose the first p+q+r+1 vectors num1 = 1 + self.p + self.q + self.r a1 = torch.hsplit(a, num1) return torch.stack(a1)
[docs] def compute_sigmas_single(self, list_h_emb, list_r_emb, list_t_emb): '''here we compute all the sums with no others vectors interaction taken with the scalar product with t, that is, .. math:: s0 = h_0r_0t_0 s1 = \sum_{i=1}^{p}h_ir_it_0 s2 = \sum_{j=p+1}^{p+q}h_jr_jt_0 s3 = \sum_{i=1}^{q}(h_0r_it_i + h_ir_0t_i) s4 = \sum_{i=p+1}^{p+q}(h_0r_it_i + h_ir_0t_i) s5 = \sum_{i=p+q+1}^{p+q+r}(h_0r_it_i + h_ir_0t_i) and return: .. math:: sigma_0t = \sigma_0 \cdot t_0 = s0 + s1 -s2 s3, s4 and s5 ''' p = self.p q = self.q r = self.r h_0 = list_h_emb[0] # h_i = list_h_emb[i] similarly for r and t r_0 = list_r_emb[0] t_0 = list_t_emb[0] s0 = (h_0 * r_0 * t_0).sum(dim=1) s1 = (t_0 * (list_h_emb[1:p + 1] * list_r_emb[1:p + 1])).sum(dim=[-1, 0]) s2 = (t_0 * (list_h_emb[p + 1:p + q + 1] * list_r_emb[p + 1:p + q + 1])).sum(dim=[-1, 0]) s3 = (h_0 * (list_r_emb[1:p + 1] * list_t_emb[1:p + 1]) + r_0 * ( list_h_emb[1:p + 1] * list_t_emb[1:p + 1])).sum(dim=[-1, 0]) s4 = (h_0 * (list_r_emb[p + 1:p + q + 1] * list_t_emb[p + 1:p + q + 1]) + r_0 * ( list_h_emb[p + 1:p + q + 1] * list_t_emb[p + 1:p + q + 1])).sum(dim=[-1, 0]) s5 = (h_0 * (list_r_emb[p + q + 1:p + q + r + 1] * list_t_emb[p + q + 1:p + q + r + 1]) + r_0 * ( list_h_emb[p + q + 1:p + q + r + 1] * list_t_emb[p + q + 1:p + q + r + 1])).sum(dim=[-1, 0]) sigma_0t = s0 + s1 - s2 return sigma_0t, s3, s4, s5
[docs] def compute_sigmas_multivect(self, list_h_emb, list_r_emb): '''Here we compute and return all the sums with vectors interaction for the same and different bases. For same bases vectors interaction we have .. math:: \sigma_pp = \sum_{i=1}^{p-1}\sum_{i'=i+1}^{p}(h_ir_{i'}-h_{i'}r_i) (models the interactions between e_i and e_i' for 1 <= i, i' <= p) \sigma_qq = \sum_{j=p+1}^{p+q-1}\sum_{j'=j+1}^{p+q}(h_jr_{j'}-h_{j'} (models the interactions between e_j and e_j' for p+1 <= j, j' <= p+q) \sigma_rr = \sum_{k=p+q+1}^{p+q+r-1}\sum_{k'=k+1}^{p}(h_kr_{k'}-h_{k'}r_k) (models the interactions between e_k and e_k' for p+q+1 <= k, k' <= p+q+r) For different base vector interactions, we have .. math:: \sigma_pq = \sum_{i=1}^{p}\sum_{j=p+1}^{p+q}(h_ir_j - h_jr_i) (interactionsn between e_i and e_j for 1<=i <=p and p+1<= j <= p+q) \sigma_pr = \sum_{i=1}^{p}\sum_{k=p+q+1}^{p+q+r}(h_ir_k - h_kr_i) (interactionsn between e_i and e_k for 1<=i <=p and p+q+1<= k <= p+q+r) \sigma_qr = \sum_{j=p+1}^{p+q}\sum_{j=p+q+1}^{p+q+r}(h_jr_k - h_kr_j) (interactionsn between e_j and e_k for p+1 <= j <=p+q and p+q+1<= j <= p+q+r) ''' p = self.p q = self.q r = self.r if p > 0: indices_i = torch.arange(1, p) sigma_pp = ((list_h_emb[indices_i] * list_r_emb[indices_i + 1].sum(dim=0)) - ( list_h_emb[indices_i + 1].sum(dim=0) * list_r_emb[indices_i])).sum(dim=[-1, 0]) else: indices_i = [] sigma_pp = 0 if q > 0: indices_j = torch.arange(p + 1, p + q) sigma_qq = ((list_h_emb[indices_j] * list_r_emb[indices_j + 1].sum(dim=0)) - ( list_h_emb[indices_j + 1].sum(dim=0) * list_r_emb[indices_j])).sum(dim=[-1, 0]) else: indices_j = [] sigma_qq = 0 if r > 0: indices_k = torch.arange(p + q + 1, p + q + r) sigma_rr = ((list_h_emb[indices_k] * list_r_emb[indices_k + 1].sum(dim=0)) - ( list_h_emb[indices_k + 1].sum(dim=0) * list_r_emb[indices_k])).sum(dim=[-1, 0]) else: indices_k = [] sigma_rr = 0 sigma_pq = ((list_h_emb[indices_i] * list_r_emb[indices_j].sum(dim=0)) - ( list_h_emb[indices_j].sum(dim=0) * list_r_emb[indices_i])).sum(dim=[-1, 0]) sigma_pr = ((list_h_emb[indices_i] * list_r_emb[indices_k].sum(dim=0)) - ( list_h_emb[indices_k].sum(dim=0) * list_r_emb[indices_i])).sum(dim=[-1, 0]) sigma_qr = ((list_h_emb[indices_j] * list_r_emb[indices_k].sum(dim=0)) - ( list_h_emb[indices_k].sum(dim=0) * list_r_emb[indices_j])).sum(dim=[-1, 0]) return sigma_pp, sigma_qq, sigma_rr, sigma_pq, sigma_pr, sigma_qr
[docs] def forward_k_vs_all(self, x: torch.Tensor) -> torch.FloatTensor: """ Kvsall training (1) Retrieve real-valued embedding vectors for heads and relations (2) Construct head entity and relation embeddings according to Cl_{p,q, r}(\mathbb{R}^d) . (3) Perform Cl multiplication (4) Inner product of (3) and all entity embeddings forward_k_vs_with_explicit and this funcitons are identical Parameter --------- x: torch.LongTensor with (n, ) shape Returns ------- torch.FloatTensor with (n, |E|) shape """ # (1) Retrieve real-valued embedding vectors. head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x) # (2) Construct multi-vector in Cl_{p,q} (\mathbb{R}^d) for head entities and relations h0, hp, hq, hk = self.construct_cl_multivector(head_ent_emb, re=self.re, p=self.p, q=self.q, r=self.r) r0, rp, rq, rk = self.construct_cl_multivector(rel_ent_emb, re=self.re, p=self.p, q=self.q, r=self.r) h0, hp, hq, hk, h0, rp, rq, rk = self.apply_coefficients(h0, hp, hq, hk, h0, rp, rq, rk) # (3) Extract all entity embeddings E = self.entity_embeddings.weight # (3.1) Extract real part t0 = E[:, :self.re] # (4) Compute a triple score based on interactions described by the basis 1. h0r0t0 = torch.einsum('br,er->be', h0 * r0, t0) # (5) Compute a triple score based on interactions described by the bases of p {e_1, ..., e_p}. if self.p > 0: tp = E[:, self.re: self.re + (self.re * self.p)].view(self.num_entities, self.re, self.p) hp_rp_t0 = torch.einsum('brp, er -> be', hp * rp, t0) h0_rp_tp = torch.einsum('brp, erp -> be', torch.einsum('br, brp -> brp', h0, rp), tp) hp_r0_tp = torch.einsum('brp, erp -> be', torch.einsum('brp, br -> brp', hp, r0), tp) score_p = hp_rp_t0 + h0_rp_tp + hp_r0_tp else: score_p = 0 # (5) Compute a triple score based on interactions described by the bases of q {e_{p+1}, ..., e_{p+q}}. if self.q > 0: num = self.re + (self.re * self.p) tq = E[:, num:num + (self.re * self.q)].view(self.num_entities, self.re, self.q) h0_rq_tq = torch.einsum('brq, erq -> be', torch.einsum('br, brq -> brq', h0, rq), tq) hq_r0_tq = torch.einsum('brq, erq -> be', torch.einsum('brq, br -> brq', hq, r0), tq) hq_rq_t0 = torch.einsum('brq, er -> be', hq * rq, t0) score_q = h0_rq_tq + hq_r0_tq - hq_rq_t0 else: score_q = 0 # (6) Compute a triple score based on interactions described by the bases of q {e_{p+q+1}, ..., e_{p+q+r}}. if self.r > 0: tk = E[:, -(self.re * self.r):].view(self.num_entities, self.re, self.r) h0_rk_tk = torch.einsum('brk, erk -> be', torch.einsum('br, brk -> brk', h0, rk), tk) hk_r0_tk = torch.einsum('brk, erk -> be', torch.einsum('brk, br -> brk', hk, r0), tk) # hq_rq_t0 = torch.einsum('brq, er -> be', hq * rq, t0) score_r = h0_rk_tk + hk_r0_tk else: score_r = 0 if self.p >= 2: sigma_pp = torch.sum(self.compute_sigma_pp(hp, rp), dim=[1, 2]).unsqueeze(-1) else: sigma_pp = 0 if self.q >= 2: sigma_qq = torch.sum(self.compute_sigma_qq(hq, rq), dim=[1, 2]).unsqueeze(-1) else: sigma_qq = 0 if self.r >= 2: sigma_rr = torch.sum(self.compute_sigma_rr(hk, rk), dim=[1, 2]).unsqueeze(-1) else: sigma_rr = 0 if self.p >= 2 and self.q >= 2: sigma_pq = torch.sum(self.compute_sigma_pq(hp=hp, hq=hq, rp=rp, rq=rq), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_pq = 0 if self.p >= 2 and self.r >= 2: sigma_pr = torch.sum(self.compute_sigma_pr(hp=hp, hk=hk, rp=rp, rk=rk), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_pr = 0 if self.q >= 2 and self.r >= 2: sigma_qr = torch.sum(self.compute_sigma_qr(hq=hq, hk=hk, rq=rq, rk=rk), dim=[1, 2, 3]).unsqueeze(-1) else: sigma_qr = 0 return h0r0t0 + score_p + score_q + score_r + sigma_pp + sigma_qq + sigma_rr + sigma_pq + sigma_pr + sigma_qr
[docs] def apply_coefficients(self, h0, hp, hq, hk, r0, rp, rq, rk): """ Multiplying a base vector with its scalar coefficient """ if self.p > 0: hp = hp * self.p_coefficients.weight rp = rp * self.p_coefficients.weight if self.q > 0: hq = hq * self.q_coefficients.weight rq = rq * self.q_coefficients.weight if self.r > 0: hk = hk * self.r_coefficients.weight rk = rk * self.r_coefficients.weight return h0, hp, hq, hk, r0, rp, rq, rk
[docs] def construct_cl_multivector(self, x: torch.FloatTensor, re: int, p: int, q: int, r: int) -> tuple[ torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """ Construct a batch of multivectors Cl_{p,q,r}(\mathbb{R}^d) Parameter --------- x: torch.FloatTensor with (n,d) shape Returns ------- a0: torch.FloatTensor ap: torch.FloatTensor aq: torch.FloatTensor ar: torch.FloatTensor """ batch_size, d = x.shape # (1) A_{n \times k}: take the first k columns a0 = x[:, :re].view(batch_size, re) # (2) B_{n \times p}, C_{n \times q}: take the self.k * self.p columns after the k. column if p > 0: ap = x[:, re: re + (re * p)].view(batch_size, re, p) else: ap = torch.zeros((batch_size, re, p), device=self.device) if q > 0: # (3) B_{n \times p}, C_{n \times q}: take the last self.r * self.q . aq = x[:, re + (re * p):re + (re * p) + (re * q):].view(batch_size, re, q) else: aq = torch.zeros((batch_size, re, q), device=self.device) if r > 0: # (3) B_{n \times p}, C_{n \times q}: take the last self.r * self.q . ar = x[:, -(re * r):].view(batch_size, re, r) else: ar = torch.zeros((batch_size, re, r), device=self.device) return a0, ap, aq, ar
[docs] def compute_sigma_pp(self, hp, rp): """ Compute .. math:: \sigma_{p,p}^* = \sum_{i=1}^{p-1}\sum_{i'=i+1}^{p}(x_iy_{i'}-x_{i'}y_i) \sigma_{pp} captures the interactions between along p bases For instance, let p e_1, e_2, e_3, we compute interactions between e_1 e_2, e_1 e_3 , and e_2 e_3 This can be implemented with a nested two for loops results = [] for i in range(p - 1): for k in range(i + 1, p): results.append(hp[:, :, i] * rp[:, :, k] - hp[:, :, k] * rp[:, :, i]) sigma_pp = torch.stack(results, dim=2) assert sigma_pp.shape == (b, r, int((p * (p - 1)) / 2)) Yet, this computation would be quite inefficient. Instead, we compute interactions along all p, e.g., e1e1, e1e2, e1e3, e2e1, e2e2, e2e3, e3e1, e3e2, e3e3 Then select the triangular matrix without diagonals: e1e2, e1e3, e2e3. """ # Compute indexes for the upper triangle of p by p matrix indices = torch.triu_indices(self.p, self.p, offset=1) # Compute p by p operations sigma_pp = torch.einsum('nrp,nrx->nrpx', hp, rp) - torch.einsum('nrx,nrp->nrpx', hp, rp) sigma_pp = sigma_pp[:, :, indices[0], indices[1]] return sigma_pp
[docs] def compute_sigma_qq(self, hq, rq): """ Compute .. math:: \sigma_{q,q}^* = \sum_{j=p+1}^{p+q-1}\sum_{j'=j+1}^{p+q}(x_jy_{j'}-x_{j'}y_j) Eq. 16 sigma_{q} captures the interactions between along q bases For instance, let q e_1, e_2, e_3, we compute interactions between e_1 e_2, e_1 e_3 , and e_2 e_3 This can be implemented with a nested two for loops results = [] for j in range(q - 1): for k in range(j + 1, q): results.append(hq[:, :, j] * rq[:, :, k] - hq[:, :, k] * rq[:, :, j]) sigma_qq = torch.stack(results, dim=2) assert sigma_qq.shape == (b, r, int((q * (q - 1)) / 2)) Yet, this computation would be quite inefficient. Instead, we compute interactions along all p, e.g., e1e1, e1e2, e1e3, e2e1, e2e2, e2e3, e3e1, e3e2, e3e3 Then select the triangular matrix without diagonals: e1e2, e1e3, e2e3. """ # Compute indexes for the upper triangle of p by p matrix if self.q > 1: indices = torch.triu_indices(self.q, self.q, offset=1) # Compute p by p operations sigma_qq = torch.einsum('nrp,nrx->nrpx', hq, rq) - torch.einsum('nrx,nrp->nrpx', hq, rq) sigma_qq = sigma_qq[:, :, indices[0], indices[1]] else: sigma_qq = torch.zeros((len(hq), self.re, int((self.q * (self.q - 1)) / 2))) return sigma_qq
[docs] def compute_sigma_rr(self, hk, rk): """ .. math:: \sigma_{r,r}^* = \sum_{k=p+q+1}^{p+q+r-1}\sum_{k'=k+1}^{p}(x_ky_{k'}-x_{k'}y_k) """ # Compute indexes for the upper triangle of p by p matrix if self.r > 1: indices = torch.triu_indices(self.r, self.r, offset=1) # Compute p by p operations sigma_rr = torch.einsum('nrp,nrx->nrpx', hk, rk) - torch.einsum('nrx,nrp->nrpx', hk, rk) sigma_rr = sigma_rr[:, :, indices[0], indices[1]] else: sigma_rr = torch.zeros((len(hk), self.re, int((self.r * (self.r - 1)) / 2))) return sigma_rr
[docs] def compute_sigma_pq(self, *, hp, hq, rp, rq): """ Compute .. math:: \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j results = [] sigma_pq = torch.zeros(b, r, p, q) for i in range(p): for j in range(q): sigma_pq[:, :, i, j] = hp[:, :, i] * rq[:, :, j] - hq[:, :, j] * rp[:, :, i] print(sigma_pq.shape) """ sigma_pq = torch.einsum('nrp,nrq->nrpq', hp, rq) - torch.einsum('nrq,nrp->nrpq', hq, rp) assert sigma_pq.shape[1:] == (self.re, self.p, self.q) return sigma_pq
[docs] def compute_sigma_pr(self, *, hp, hk, rp, rk): """ Compute .. math:: \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j results = [] sigma_pq = torch.zeros(b, r, p, q) for i in range(p): for j in range(q): sigma_pq[:, :, i, j] = hp[:, :, i] * rq[:, :, j] - hq[:, :, j] * rp[:, :, i] print(sigma_pq.shape) """ sigma_pr = torch.einsum('nrp,nrk->nrpk', hp, rk) - torch.einsum('nrk,nrp->nrpk', hk, rp) assert sigma_pr.shape[1:] == (self.re, self.p, self.r) return sigma_pr
[docs] def compute_sigma_qr(self, *, hq, hk, rq, rk): """ .. math:: \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j results = [] sigma_pq = torch.zeros(b, r, p, q) for i in range(p): for j in range(q): sigma_pq[:, :, i, j] = hp[:, :, i] * rq[:, :, j] - hq[:, :, j] * rp[:, :, i] print(sigma_pq.shape) """ sigma_qr = torch.einsum('nrq,nrk->nrqk', hq, rk) - torch.einsum('nrk,nrq->nrqk', hk, rq) assert sigma_qr.shape[1:] == (self.re, self.q, self.r) return sigma_qr