diff --git a/index.ts b/index.ts index 2ca47eb..6516206 100644 --- a/index.ts +++ b/index.ts @@ -79,7 +79,7 @@ export default function (pi: ExtensionAPI) { }) pi.on("context", async (event, ctx) => { - const prunedMessages = await applyPruning(event.messages, state, config, (ctx as any).apiClient) + const prunedMessages = await applyPruning(event.messages, state, config, ctx.model) return { messages: prunedMessages } }) diff --git a/pruner.ts b/pruner.ts index 69c3234..4ecacd0 100644 --- a/pruner.ts +++ b/pruner.ts @@ -132,7 +132,7 @@ async function generateSummary( turns: any[], previousSummary: string | null, focusTopic: string | null, - apiClient: any, + model: any, ): Promise { const contentToSummarize = serializeForSummary(turns); @@ -213,13 +213,23 @@ Prioritize preserving all information related to the focus topic.`; } try { - const model = (apiClient && apiClient.model) ? apiClient.model : "gemini-2.0-flash"; - const response = await apiClient.chat.completions.create({ - model: model, - messages: [{ role: "user", content: prompt }], - max_tokens: 4000, + if (!model) return null; + const piAi = await import("@mariozechner/pi-ai"); + const response = await piAi.complete(model, { + messages: [{ role: "user", content: prompt, timestamp: Date.now() }] }); - return response.choices[0]?.message?.content?.trim() || null; + + let text = ""; + if (Array.isArray(response.content)) { + text = response.content + .filter((c: any) => c.type === "text") + .map((c: any) => c.text) + .join(""); + } else if (typeof (response as any).content === "string") { + text = (response as any).content; + } + + return text.trim() || null; } catch (e) { console.error("Summary generation failed:", e); return null; @@ -333,7 +343,7 @@ export async function applyPruning( messages: any[], state: DcpState, config: DcpConfig, - apiClient: any + model: any ): Promise { const msgs = messages.map((m: any) => { const clone = { ...m }; @@ -369,7 +379,7 @@ export async function applyPruning( if (compressStart < compressEnd) { const middle = msgs.slice(compressStart, compressEnd); - const summary = await generateSummary(middle, state.previousSummary, null, apiClient); + const summary = await generateSummary(middle, state.previousSummary, null, model); if (summary) { const compressed: any[] = [];