some people get errors with view instead of reshape
This commit is contained in:
committed by
GitHub
parent
e7b68717f0
commit
5a58cd945a
@@ -131,7 +131,7 @@ class Attention(nn.Module):
|
||||
query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
|
||||
attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
|
||||
attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
|
||||
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
|
||||
attn_output = attn_output.reshape(batch_size, seq_len, self.output_size) # type: ignore
|
||||
return self.o_proj(attn_output)
|
||||
|
||||
class LinearSwish(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user