Files
pleroma-ollama-bot/src/main.ts

189 lines
5.0 KiB
TypeScript

import {
OllamaRequest,
OllamaResponse,
NewStatusBody,
Notification,
WSEvent,
} from "../types.js";
import striptags from "striptags";
import { PrismaClient } from "../generated/prisma/client.js";
import { createWebsocket } from "./websocket.js";
// Single Prisma client shared by the database helpers below
// (storeUserData, alreadyRespondedTo, storePromptData).
const prisma = new PrismaClient();
/**
 * Records that the bot is responding to the author of `notification`.
 *
 * Upserts the user row keyed by the author's fully-qualified name: a new row
 * is created on first contact, otherwise only `lastRespondedTo` is refreshed.
 *
 * Database errors propagate unchanged (no re-wrapping), so callers keep the
 * original error type and stack trace.
 *
 * @param notification - notification whose `status.account.fqn` identifies the user.
 */
const storeUserData = async (notification: Notification): Promise<void> => {
  const userFqn = notification.status.account.fqn;
  const now = new Date();
  await prisma.user.upsert({
    where: { userFqn },
    update: { lastRespondedTo: now },
    create: { userFqn, lastRespondedTo: now },
  });
};
/**
 * Reports whether a response has already been stored for the status that
 * triggered `notification`.
 *
 * Lookup is by `notification.status.id`, stored in the
 * `pleromaNotificationId` column (NOTE(review): the column name suggests a
 * notification id, but the status id is what is written — see storePromptData).
 *
 * Database errors propagate unchanged so the original stack trace survives.
 *
 * @returns `true` if a matching response row exists, otherwise `false`.
 */
const alreadyRespondedTo = async (
  notification: Notification
): Promise<boolean> => {
  const duplicate = await prisma.response.findFirst({
    where: { pleromaNotificationId: notification.status.id },
  });
  // findFirst resolves to the row or null.
  return duplicate !== null;
};
/**
 * Persists a prompt/response exchange so the same status is not answered twice.
 *
 * @param notification - the triggering notification; its status content
 *   (HTML stripped) is stored as the request and its status id as
 *   `pleromaNotificationId`.
 * @param ollamaResponseBody - the Ollama reply whose `response` text is stored.
 *
 * Database errors propagate unchanged so the original stack trace survives.
 */
const storePromptData = async (
  notification: Notification,
  ollamaResponseBody: OllamaResponse
): Promise<void> => {
  await prisma.response.create({
    data: {
      response: ollamaResponseBody.response,
      request: striptags(notification.status.content),
      // NOTE(review): uses notification.account while storeUserData uses
      // notification.status.account; presumably identical for mentions — confirm.
      to: notification.account.fqn,
      pleromaNotificationId: notification.status.id,
    },
  });
};
/**
 * Extracts the user's prompt from a status body.
 *
 * HTML tags are stripped first, then everything after the first `!prompt`
 * keyword (one that is followed by whitespace or the end of the text) is
 * returned with surrounding whitespace trimmed.
 *
 * The keyword is located in the text itself rather than as a space-separated
 * token. Content such as `@bot<br>!prompt hi` (stripped to `@bot!prompt hi`) or
 * a newline right after the keyword is therefore handled correctly instead of
 * leaking the whole text, mention included, into the prompt.
 *
 * @param input - raw (HTML) status content.
 * @returns the prompt text. If no standalone `!prompt` is present, the whole
 *   stripped input is returned (same as before).
 */
const trimInputData = (input: string): string => {
  const strippedInput = striptags(input);
  const match = /!prompt(?=\s|$)/.exec(strippedInput);
  if (match === null) {
    return strippedInput;
  }
  return strippedInput.slice(match.index + match[0].length).trim();
};
/**
 * Generates an Ollama completion for `notification` if it is eligible.
 *
 * A notification is eligible when all of the following hold:
 *  - its status text (HTML stripped) contains "!prompt";
 *  - its author is not flagged as a bot;
 *  - when ONLY_LOCAL_REPLIES is "true", the author's fqn ends with
 *    `@PLEROMA_INSTANCE_DOMAIN`. A suffix match is used so that a remote
 *    domain like `example.com.evil.net` does not pass;
 *  - no response has been stored for the status yet.
 *
 * For eligible notifications, the author is recorded first. The prompt is then
 * POSTed to `${OLLAMA_URL}/api/generate` (non-streaming), and the exchange is
 * persisted.
 *
 * @returns the parsed Ollama response, or `undefined` when the notification
 *   was skipped.
 * @throws Error if Ollama answers with a non-2xx status. Database and network
 *   errors propagate unchanged.
 */
const generateOllamaRequest = async (
  notification: Notification
): Promise<OllamaResponse | undefined> => {
  const { status } = notification;

  if (!striptags(status.content).includes("!prompt") || status.account.bot) {
    return undefined;
  }
  if (
    process.env.ONLY_LOCAL_REPLIES === "true" &&
    !status.account.fqn.endsWith(`@${process.env.PLEROMA_INSTANCE_DOMAIN}`)
  ) {
    return undefined;
  }
  if (await alreadyRespondedTo(notification)) {
    return undefined;
  }

  await storeUserData(notification);

  const ollamaRequestBody: OllamaRequest = {
    model: process.env.OLLAMA_MODEL as string,
    system: process.env.OLLAMA_SYSTEM_PROMPT as string,
    prompt: `@${status.account.fqn} says: ${trimInputData(status.content)}`,
    stream: false,
  };
  const response = await fetch(`${process.env.OLLAMA_URL}/api/generate`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(ollamaRequestBody),
  });
  // Without this check, an Ollama error body would be stored and posted as a reply.
  if (!response.ok) {
    throw new Error(
      `Ollama request failed: ${response.status} ${response.statusText}`
    );
  }
  const ollamaResponse: OllamaResponse = await response.json();
  await storePromptData(notification, ollamaResponse);
  return ollamaResponse;
};
/**
 * Posts the Ollama output as a Markdown reply to the status that triggered
 * `notification`.
 *
 * If the incoming status mentions any accounts, their `acct` handles are sent
 * as the explicit `to` list of the new status.
 *
 * NOTE(review): no visibility is sent, so the server presumably applies the
 * bot account's default visibility even when replying to a private/direct
 * status. Confirm whether NewStatusBody should carry the original visibility.
 *
 * @param notification - notification whose status is being replied to.
 * @param ollamaResponseBody - Ollama output; `response` becomes the status text.
 * @throws Error when the Pleroma API responds with a non-2xx status. Network
 *   errors from fetch propagate unchanged.
 */
const postReplyToStatus = async (
  notification: Notification,
  ollamaResponseBody: OllamaResponse
): Promise<void> => {
  const statusBody: NewStatusBody = {
    content_type: "text/markdown",
    status: ollamaResponseBody.response,
    in_reply_to_id: notification.status.id,
  };
  const { mentions } = notification.status;
  if (mentions && mentions.length > 0) {
    statusBody.to = mentions.map((mention) => mention.acct);
  }

  const response = await fetch(
    `${process.env.PLEROMA_INSTANCE_URL}/api/v1/statuses`,
    {
      method: "POST",
      headers: {
        Authorization: `Bearer ${process.env.INSTANCE_BEARER_TOKEN}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify(statusBody),
    }
  );
  if (!response.ok) {
    // statusText may be empty (e.g. over HTTP/2), so include the numeric code.
    throw new Error(
      `New status request failed: ${response.status} ${response.statusText}`
    );
  }
};
/**
 * Streaming connection to the Pleroma instance.
 *
 * - Logs when the connection upgrade succeeds and when the connection closes.
 * - Sends a JSON `{ type: "ping" }` keep-alive every 20 s while open.
 * - For every "notification" event, asks Ollama for a reply and posts it.
 */
const ws = createWebsocket();
let keepAliveTimer: ReturnType<typeof setInterval> | undefined;

const stopKeepAlive = (): void => {
  if (keepAliveTimer !== undefined) {
    clearInterval(keepAliveTimer);
    keepAliveTimer = undefined;
  }
};

ws.on("upgrade", () => {
  console.log(
    `Websocket connection to ${process.env.PLEROMA_INSTANCE_DOMAIN} successful.`
  );
});

// The `ws` client emits (code, reason: Buffer), not a DOM CloseEvent; reading
// `event.reason` previously always logged "undefined".
ws.on("close", (code: number, reason: Buffer) => {
  // Clear the keep-alive so a dead socket doesn't keep the process alive.
  stopKeepAlive();
  console.log(`Connection closed: ${reason.toString()} (code ${code})`);
});

ws.on("open", () => {
  stopKeepAlive();
  keepAliveTimer = setInterval(() => {
    ws.send(JSON.stringify({ type: "ping" }));
  }, 20000);
});

ws.on("message", async (data) => {
  try {
    const message: WSEvent = JSON.parse(data.toString("utf-8"));
    if (message.event !== "notification") {
      // only watch for notification events
      return;
    }
    console.log("Websocket message received.");
    // The streaming API delivers the payload as a JSON-encoded string.
    const payload = JSON.parse(message.payload) as Notification;
    const ollamaResponse = await generateOllamaRequest(payload);
    if (ollamaResponse) {
      await postReplyToStatus(payload, ollamaResponse);
    }
  } catch (error: unknown) {
    console.error(error instanceof Error ? error.message : error);
  }
});