mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 18:03:14 +08:00
another example
This commit is contained in:
@@ -18,14 +18,15 @@ from torch import Tensor
|
||||
default_output_folder = (Path(__file__).parent.parent / "outputs").resolve()
|
||||
|
||||
def default_postprocess_result(input: dict, trace: TraceDict, output: ModelOutput, model: AutoModelForCausalLM) -> Dict[str, Tensor]:
|
||||
"""add activations to output, and rearrange hidden states"""
|
||||
"""Make your own. This adds activations to output, and rearranges hidden states"""
|
||||
|
||||
# Baukit records the literal layer output, which varies by model. Here we assume that the output or the first part are activations we want
|
||||
acts = {f'act-{k}':
|
||||
v.output[0] if isinstance(v.output, tuple) else v.output
|
||||
for k, v in trace.items()}
|
||||
|
||||
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h')
|
||||
# batch must be first; the writer also supports float16, so let's use that
|
||||
output.hidden_states = rearrange(list(output.hidden_states), 'l b t h -> b l t h').half()
|
||||
|
||||
return dict(attention_mask=input["attention_mask"], **acts, **output)
|
||||
|
||||
|
||||
+378
-2
@@ -181,7 +181,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -224,7 +224,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"f = activation_store(ds, model, layers=layers, writer_batch_size=10)\n",
|
||||
"f = activation_store(ds, model, layers=layers)\n",
|
||||
"f"
|
||||
]
|
||||
},
|
||||
@@ -292,6 +292,382 @@
|
||||
"source": [
|
||||
"ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Get supressed activations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jaxtyping import Float, Int\n",
|
||||
"from torch import Tensor\n",
|
||||
"from einops import rearrange\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_supressed_activations(\n",
|
||||
" hs: Float[Tensor, \"l b t h\"], w_out, w_inv\n",
|
||||
") -> Float[Tensor, \"l b t h\"]:\n",
|
||||
" \"\"\"\n",
|
||||
" Novel experiment: Here we define a transform to isolate supressed activations, where we hypothesis that style/concepts/scratchpads and other internal only representations must be stored.\n",
|
||||
"\n",
|
||||
" See the following references for more information:\n",
|
||||
"\n",
|
||||
" - https://arxiv.org/pdf/2401.12181\n",
|
||||
" - > Suppression neurons that are similar, except decrease the probability of a group of related tokens\n",
|
||||
"\n",
|
||||
" - https://arxiv.org/html/2406.19384\n",
|
||||
" - > Previous work suggests that networks contain ensembles of “prediction\" neurons, which act as probability promoters [66, 24, 32] and work in tandem with suppression neurons (Section 5.4).\n",
|
||||
"\n",
|
||||
" - https://arxiv.org/pdf/2401.12181\n",
|
||||
" > We find a striking pattern which is remarkably consistent across the different seeds: after about the halfway point in the model, prediction neurons become increasingly prevalent until the very end of the network where there is a sudden shift towards a much larger number of suppression neurons.\n",
|
||||
" \"\"\"\n",
|
||||
" with torch.no_grad():\n",
|
||||
" # here we pass the hs through the last layer, take a diff, and then project it back to find which activation changes lead to supressed\n",
|
||||
" hs2 = rearrange(hs[:, :, -1:], \"l b t h -> (l b t) h\")\n",
|
||||
" hs_out2 = torch.nn.functional.linear(hs2, w_out)\n",
|
||||
" hs_out = rearrange(\n",
|
||||
" hs_out2, \"(l b t) h -> l b t h\", l=hs.shape[0], b=hs.shape[1], t=1\n",
|
||||
" )\n",
|
||||
" diffs = hs_out[:, :, :].diff(dim=0)\n",
|
||||
" diffs2 = rearrange(diffs, \"l b t h -> (l b t) h\")\n",
|
||||
" # W_inv = get_cache_inv(w_out)\n",
|
||||
"\n",
|
||||
" diffs_inv2 = torch.nn.functional.linear(diffs2.to(dtype=w_inv.dtype), w_inv)\n",
|
||||
" diffs_inv = rearrange(\n",
|
||||
" diffs_inv2, \"(l b t) h -> l b t h\", l=hs.shape[0] - 1, b=hs.shape[1], t=1\n",
|
||||
" ).to(w_out.dtype)\n",
|
||||
" # TODO just return this?\n",
|
||||
" eps = 1.0e-1\n",
|
||||
" supressed_mask = (diffs_inv < -eps).to(hs.dtype)\n",
|
||||
" # supressed_mask = repeat(supressed_mask, 'l b 1 h -> l b t h', t=hs.shape[2])\n",
|
||||
" supressed_act = hs[1:] * supressed_mask\n",
|
||||
" return supressed_act"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from activation_store.collect import default_postprocess_result\n",
|
||||
"\n",
|
||||
"Wo = model.get_output_embeddings().weight.detach().clone().cpu()\n",
|
||||
"Wo_inv = torch.pinverse(Wo.clone().float())\n",
|
||||
"\n",
|
||||
"@torch.no_grad()\n",
|
||||
"def sup_postproc(input, trace, output, model):\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" o = default_postprocess_result(input, trace, output, model)\n",
|
||||
" \n",
|
||||
" hs = o.pop('hidden_states')\n",
|
||||
" hs = rearrange(hs, \"b l t h -> l b t h\")\n",
|
||||
" hs_s = get_supressed_activations(hs, Wo.to(hs.dtype), Wo_inv.to(hs.dtype))\n",
|
||||
" hs_s = rearrange(hs_s, \"l b t h -> b l t h\")\n",
|
||||
" o['hidden_states_supressed'] = hs_s.half()\n",
|
||||
" \n",
|
||||
" return o\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 86,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32m2025-02-16 09:52:12.917\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m78\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c0c38f37f9934a0dbe7086b695624548",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"collecting activations: 0%| | 0/5 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet')"
|
||||
]
|
||||
},
|
||||
"execution_count": 86,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"f2 = activation_store(ds, model, postprocess_result=sup_postproc)\n",
|
||||
"f2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "2380a4d27fdb42e9a61adb67bd221cd8",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Generating train split: 0 examples [00:00, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Dataset({\n",
|
||||
" features: ['attention_mask', 'logits', 'hidden_states_supressed'],\n",
|
||||
" num_rows: 20\n",
|
||||
"})"
|
||||
]
|
||||
},
|
||||
"execution_count": 87,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ds_a2 = Dataset.from_parquet(str(f2)).with_format(\"torch\")\n",
|
||||
"ds_a2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 92,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([2, 24, 453, 896])"
|
||||
]
|
||||
},
|
||||
"execution_count": 92,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ds_a2.info\n",
|
||||
"ds_a2[0:2]['hidden_states_supressed'].shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 93,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" -0.0000e+00, -0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, -0.0000e+00],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, -0.0000e+00]],\n",
|
||||
"\n",
|
||||
" [[ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -0.0000e+00],\n",
|
||||
" [-0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, -0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -0.0000e+00]],\n",
|
||||
"\n",
|
||||
" [[-0.0000e+00, -0.0000e+00, -2.9629e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, -0.0000e+00, -1.9275e-01, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, 9.8419e-03, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, -0.0000e+00, 3.3855e-03, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, -0.0000e+00, -9.2627e-01, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, -1.6125e-01, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -0.0000e+00]],\n",
|
||||
"\n",
|
||||
" ...,\n",
|
||||
"\n",
|
||||
" [[ 0.0000e+00, -2.7559e+00, -5.6543e-01, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, 1.0596e+00],\n",
|
||||
" [-0.0000e+00, 6.6846e-01, -3.3643e-01, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, 4.6387e-01],\n",
|
||||
" [ 0.0000e+00, 7.1484e-01, 2.5854e-01, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, 1.4053e+00],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, 1.0898e+00, 4.5239e-01, ..., 0.0000e+00,\n",
|
||||
" -0.0000e+00, -2.1033e-01],\n",
|
||||
" [-0.0000e+00, 2.1504e+00, -1.2415e-01, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 1.3428e+00],\n",
|
||||
" [-0.0000e+00, 2.8262e+00, 5.9277e-01, ..., 0.0000e+00,\n",
|
||||
" -0.0000e+00, -5.5811e-01]],\n",
|
||||
"\n",
|
||||
" [[ 0.0000e+00, -2.4727e+00, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -7.0947e-01, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, 1.3535e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" 1.1560e-01, -0.0000e+00],\n",
|
||||
" [ 0.0000e+00, 1.1572e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 1.0020e+00, 0.0000e+00],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, 1.0254e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" -5.9326e-02, -0.0000e+00],\n",
|
||||
" [-0.0000e+00, 2.2461e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -2.3105e+00, 0.0000e+00],\n",
|
||||
" [ 0.0000e+00, 1.5420e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" -5.2979e-01, -0.0000e+00]],\n",
|
||||
"\n",
|
||||
" [[ 0.0000e+00, -0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" -2.5605e+00, 6.3984e+00],\n",
|
||||
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -1.1279e+00, -1.1766e+01],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" 1.3135e+00, -3.2148e+00],\n",
|
||||
" ...,\n",
|
||||
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 7.1289e-01, -8.5234e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -1.0047e+01, 1.1797e+01],\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" -1.9268e+00, -9.5078e+00]]],\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" [[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., 1.1493e-01,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 7.6843e-02,\n",
|
||||
" -0.0000e+00, -0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -1.0968e-01,\n",
|
||||
" -0.0000e+00, -0.0000e+00],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.1094e-01,\n",
|
||||
" -0.0000e+00, -0.0000e+00],\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.5928e-01,\n",
|
||||
" -0.0000e+00, -0.0000e+00],\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -2.5220e-01,\n",
|
||||
" -0.0000e+00, -0.0000e+00]],\n",
|
||||
"\n",
|
||||
" [[ 0.0000e+00, -0.0000e+00, -0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -0.0000e+00],\n",
|
||||
" [-0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, -0.0000e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00],\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 0.0000e+00]],\n",
|
||||
"\n",
|
||||
" [[-0.0000e+00, -8.0391e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 3.9531e+00],\n",
|
||||
" [-0.0000e+00, -4.0186e-01, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 1.0168e-01],\n",
|
||||
" [-0.0000e+00, 1.7993e-01, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 3.4561e-03],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, -2.8122e-02, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, -6.9031e-02],\n",
|
||||
" [ 0.0000e+00, -6.8909e-02, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, -6.5613e-02],\n",
|
||||
" [ 0.0000e+00, -5.5481e-02, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, -9.4604e-02]],\n",
|
||||
"\n",
|
||||
" ...,\n",
|
||||
"\n",
|
||||
" [[ 4.4263e-01, -0.0000e+00, -5.6543e-01, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-2.8281e+00, 0.0000e+00, -3.3643e-01, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [ 6.2225e-02, 0.0000e+00, 2.5854e-01, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" ...,\n",
|
||||
" [-8.7598e-01, 0.0000e+00, 5.7422e-01, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-1.0537e+00, 0.0000e+00, 3.3667e-01, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00],\n",
|
||||
" [-1.1729e+00, 0.0000e+00, 2.0654e-01, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, 0.0000e+00]],\n",
|
||||
"\n",
|
||||
" [[ 0.0000e+00, -0.0000e+00, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, 1.6553e+00],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, -3.4204e-01],\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, 2.3914e-01],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -8.4375e-01],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -9.8193e-01],\n",
|
||||
" [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -9.7900e-01]],\n",
|
||||
"\n",
|
||||
" [[ 0.0000e+00, -6.5742e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" -0.0000e+00, 6.3984e+00],\n",
|
||||
" [-0.0000e+00, -3.4448e-01, -0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" -0.0000e+00, -1.1766e+01],\n",
|
||||
" [-0.0000e+00, 6.8408e-01, 0.0000e+00, ..., -0.0000e+00,\n",
|
||||
" 0.0000e+00, -3.2148e+00],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0000e+00, -1.3831e-01, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -1.2625e+01],\n",
|
||||
" [ 0.0000e+00, -1.0439e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -9.9453e+00],\n",
|
||||
" [ 0.0000e+00, -1.5166e+00, 0.0000e+00, ..., 0.0000e+00,\n",
|
||||
" 0.0000e+00, -1.1367e+01]]]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 93,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ds_a2[0:2]['hidden_states_supressed']"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user