189 lines
5.0 KiB
TypeScript
189 lines
5.0 KiB
TypeScript
import {
|
|
OllamaRequest,
|
|
OllamaResponse,
|
|
NewStatusBody,
|
|
Notification,
|
|
WSEvent,
|
|
} from "../types.js";
|
|
import striptags from "striptags";
|
|
import { PrismaClient } from "../generated/prisma/client.js";
|
|
import { createWebsocket } from "./websocket.js";
|
|
|
|
// Shared Prisma client used by every database helper in this module.
const prisma: PrismaClient = new PrismaClient();

/**
 * Records that the bot has just responded to the author of `notification`.
 *
 * Upserts a `user` row keyed by the status author's `fqn`: it is created on
 * first contact, otherwise only `lastRespondedTo` is refreshed.
 *
 * @param notification - Incoming notification whose status author is recorded.
 * @throws The underlying Prisma error if the upsert fails. It propagates
 *   unchanged so callers keep the original stack trace and error class.
 */
const storeUserData = async (notification: Notification): Promise<void> => {
  const userFqn = notification.status.account.fqn;
  // One timestamp for both branches of the upsert.
  const now = new Date();

  await prisma.user.upsert({
    where: { userFqn },
    update: { lastRespondedTo: now },
    create: { userFqn, lastRespondedTo: now },
  });
};

/**
 * Reports whether a response has already been stored for the status carried
 * by `notification`.
 *
 * Stored responses are keyed by `notification.status.id` in the
 * `pleromaNotificationId` column, so this prevents answering the same status
 * twice.
 *
 * @param notification - Incoming notification to check.
 * @returns `true` if a matching `response` row exists.
 * @throws The underlying Prisma error if the lookup fails (propagated unchanged).
 */
const alreadyRespondedTo = async (
  notification: Notification
): Promise<boolean> => {
  const duplicate = await prisma.response.findFirst({
    where: { pleromaNotificationId: notification.status.id },
  });

  return duplicate !== null;
};

/**
 * Persists one prompt/response exchange for the status carried by `notification`.
 *
 * @param notification - Notification whose status contained the prompt.
 * @param ollamaResponseBody - Ollama's reply; its `response` text is stored.
 * @throws The underlying Prisma error if the insert fails (propagated unchanged).
 */
const storePromptData = async (
  notification: Notification,
  ollamaResponseBody: OllamaResponse
): Promise<void> => {
  await prisma.response.create({
    data: {
      response: ollamaResponseBody.response,
      // The full tag-stripped status text, not only the part after "!prompt".
      request: striptags(notification.status.content),
      // NOTE(review): storeUserData keys on `notification.status.account`; for
      // mention notifications these are presumably the same account — confirm.
      to: notification.account.fqn,
      // Keyed by the status id; alreadyRespondedTo looks up the same value.
      pleromaNotificationId: notification.status.id,
    },
  });
};

/**
 * Extracts the user's prompt from a status' HTML content.
 *
 * Steps:
 * - `<br>` tags become newlines so adjacent lines do not run together.
 * - Remaining tags are stripped.
 * - Everything after the first occurrence of `!prompt` is returned, trimmed.
 *
 * The keyword is matched as a substring, the same test the caller uses to
 * decide whether to respond. So `!prompt` followed by a newline or by
 * punctuation is still recognised. If the keyword is absent, the stripped text
 * is returned unchanged.
 *
 * @param input - Raw status content (HTML).
 * @returns The prompt text that follows `!prompt`.
 */
const trimInputData = (input: string): string => {
  const command = "!prompt";
  const strippedInput = striptags(input.replace(/<br\s*\/?>/gi, "\n"));
  const commandIndex = strippedInput.indexOf(command);

  if (commandIndex === -1) {
    return strippedInput;
  }

  return strippedInput.slice(commandIndex + command.length).trim();
};

/**
 * Sends the prompt in `notification` to Ollama, if the bot should answer it,
 * and records the exchange.
 *
 * The notification is ignored (resolves to `undefined`) when:
 * - it carries no status;
 * - the tag-stripped status text does not contain `!prompt`;
 * - the author is flagged as a bot;
 * - `ONLY_LOCAL_REPLIES` is `"true"` and the author is not on
 *   `PLEROMA_INSTANCE_DOMAIN`;
 * - a response for this status id has already been stored.
 *
 * Otherwise:
 * 1. The author is upserted via `storeUserData`.
 * 2. Ollama's `/api/generate` endpoint is called with streaming disabled.
 * 3. The result is persisted via `storePromptData` and returned.
 *
 * @param notification - Notification received from the Pleroma stream.
 * @returns Ollama's response, or `undefined` when the notification is ignored.
 * @throws If a database call fails or Ollama answers with a non-2xx status.
 */
const generateOllamaRequest = async (
  notification: Notification
): Promise<OllamaResponse | undefined> => {
  const status = notification.status;

  // NOTE(review): assumes some notification types (e.g. follows) arrive without
  // a status; skip them instead of throwing on `status.content`.
  if (!status) {
    return undefined;
  }

  if (!striptags(status.content).includes("!prompt") || status.account.bot) {
    return undefined;
  }

  // Match the domain as a suffix. A substring test would also accept remote
  // accounts such as `user@example.community` when the domain is `example.com`.
  if (
    process.env.ONLY_LOCAL_REPLIES === "true" &&
    !status.account.fqn.endsWith(`@${process.env.PLEROMA_INSTANCE_DOMAIN}`)
  ) {
    return undefined;
  }

  if (await alreadyRespondedTo(notification)) {
    return undefined;
  }

  await storeUserData(notification);

  const ollamaRequestBody: OllamaRequest = {
    model: process.env.OLLAMA_MODEL as string,
    system: process.env.OLLAMA_SYSTEM_PROMPT as string,
    prompt: `@${status.account.fqn} says: ${trimInputData(status.content)}`,
    stream: false,
  };

  const response = await fetch(`${process.env.OLLAMA_URL}/api/generate`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(ollamaRequestBody),
  });

  // Without this check, an error body (which presumably has no `response`
  // field) would be stored and posted as if it were a normal answer.
  if (!response.ok) {
    throw new Error(
      `Ollama request failed: ${response.status} ${response.statusText}`
    );
  }

  const ollamaResponse: OllamaResponse = await response.json();

  await storePromptData(notification, ollamaResponse);

  return ollamaResponse;
};

/**
 * Posts Ollama's answer as a reply to the status that triggered `notification`.
 *
 * The reply is sent as Markdown to `PLEROMA_INSTANCE_URL/api/v1/statuses`,
 * authenticated with `INSTANCE_BEARER_TOKEN`. When the original status has
 * mentions, their `acct` handles are sent as the explicit `to` list.
 *
 * NOTE(review): the reply does not copy the original status' visibility. A
 * prompt received as a direct message may therefore be answered with the
 * account's default visibility. Confirm whether `NewStatusBody` should carry
 * `visibility`.
 *
 * @param notification - Notification whose status is being answered.
 * @param ollamaResponseBody - Ollama's reply; its `response` text is posted.
 * @throws If the instance answers with a non-2xx status.
 */
const postReplyToStatus = async (
  notification: Notification,
  ollamaResponseBody: OllamaResponse
): Promise<void> => {
  const statusBody: NewStatusBody = {
    content_type: "text/markdown",
    status: ollamaResponseBody.response,
    in_reply_to_id: notification.status.id,
  };

  const mentions = notification.status.mentions;
  if (mentions && mentions.length > 0) {
    statusBody.to = mentions.map((mention) => mention.acct);
  }

  const response = await fetch(
    `${process.env.PLEROMA_INSTANCE_URL}/api/v1/statuses`,
    {
      method: "POST",
      headers: {
        Authorization: `Bearer ${process.env.INSTANCE_BEARER_TOKEN}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify(statusBody),
    }
  );

  if (!response.ok) {
    // `statusText` can be empty (e.g. over HTTP/2), so include the numeric code.
    throw new Error(
      `New status request failed: ${response.status} ${response.statusText}`
    );
  }
};

const ws = createWebsocket();

// Timer handle for the ping loop started on "open". It is cleared on "close"
// so the timer stops calling `send` on a closed socket.
let pingInterval: ReturnType<typeof setInterval> | undefined;

ws.on("upgrade", () => {
  console.log(
    `Websocket connection to ${process.env.PLEROMA_INSTANCE_DOMAIN} successful.`
  );
});

// NOTE(review): assumes `createWebsocket` returns a `ws` package WebSocket; its
// `.on(...)` API and the Buffer message data suggest so. That library calls
// "close" listeners with `(code, reason)`, not a DOM `CloseEvent`.
ws.on("close", (code: number, reason: Buffer) => {
  if (pingInterval !== undefined) {
    clearInterval(pingInterval);
    pingInterval = undefined;
  }
  console.log(`Connection closed (${code}): ${reason.toString()}`);
});

ws.on("open", () => {
  // Application-level ping every 20 seconds, presumably to stop the streaming
  // connection from idling out.
  pingInterval = setInterval(() => {
    ws.send(JSON.stringify({ type: "ping" }));
  }, 20000);
});

ws.on("message", async (data) => {
  try {
    const message: WSEvent = JSON.parse(data.toString("utf-8"));

    // Only notification events are handled; everything else is ignored.
    if (message.event !== "notification") {
      return;
    }

    console.log("Websocket message received.");

    // The notification itself arrives as a JSON string inside the event payload.
    const payload = JSON.parse(message.payload) as Notification;
    const ollamaResponse = await generateOllamaRequest(payload);

    if (ollamaResponse) {
      await postReplyToStatus(payload, ollamaResponse);
    }
  } catch (error: unknown) {
    console.error(error instanceof Error ? error.message : error);
  }
});
