mirror of
https://github.com/wassname/alpaca_convert.git
synced 2026-06-27 16:14:08 +08:00
Fix args of flash attention
This commit is contained in:
@@ -4,7 +4,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
import transformers
|
||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
@@ -16,13 +16,14 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
config: LlamaConfig,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
num_heads = config.num_attention_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.head_dim = self.hidden_size // num_heads
|
||||
|
||||
if (self.head_dim * num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user