Source code for target_extraction.allen.modules.word_dropout

import torch
from torch.nn.modules import Dropout2d

class WordDrouput(torch.nn.Module):
    '''
    Word Dropout will randomly drop whole words/timesteps.

    For every sample in the batch, each timestep is independently either
    kept or has its entire embedding vector zeroed, with probability ``p``
    of being zeroed. As with standard (inverted) dropout, in training mode
    the kept timesteps are scaled by ``1 / (1 - p)``; in evaluation mode the
    input is returned unchanged.

    This is implemented with :class:`torch.nn.Dropout2d` by treating every
    timestep as a "channel". Note that this is *not* the same as Keras'
    `SpatialDropout1D`_. For a ``[batch, timestep, channels]`` input, that
    layer drops entire channels (embedding dimensions) across all timesteps,
    whereas this module drops entire timesteps (words) across all embedding
    dimensions.

    .. _SpatialDropout1D: https://keras.io/api/layers/regularization_layers/spatial_dropout1d/
    '''
    def __init__(self, p: float) -> None:
        '''
        :param p: probability of a whole word/timestep to be zeroed/dropped.
        '''
        super().__init__()
        # Dropout2d zeroes whole channels (dim 1 of a 4D input); forward()
        # arranges for dim 1 to be the timestep dimension.
        self._word_dropout = Dropout2d(p)

    def forward(self, embedded_text: torch.FloatTensor) -> torch.FloatTensor:
        '''
        :param embedded_text: A tensor of shape:
                              [batch_size, timestep, embedding_dim] of which
                              the dropout will drop entire timesteps, which
                              is the equivalent to words.
        :returns: The given tensor (same shape) but with timesteps/words
                  dropped.
        '''
        # [batch, timestep, dim] -> [batch, timestep, 1, dim] so Dropout2d
        # sees timesteps as channels; each "feature map" it may zero is then
        # exactly one word's embedding vector.
        embedded_text = embedded_text.unsqueeze(2)
        embedded_text = self._word_dropout(embedded_text)
        # Remove the singleton axis to restore [batch, timestep, dim].
        return embedded_text.squeeze(2)