Merge pull request #572 from melvinebenezer/493_text_labels_api

#493 Use API to populate text_labels in frontend
This commit is contained in:
Keith Stevens
2023-01-10 08:08:47 +09:00
committed by GitHub
5 changed files with 79 additions and 49 deletions
+14 -45
View File
@@ -27,8 +27,22 @@ import poster from "src/lib/poster";
import { colors } from "styles/Theme/colors";
import useSWRMutation from "swr/mutation";
interface textFlagLabels {
attributeName: string;
labelText: string;
additionalExplanation?: string;
}
export const FlaggableElement = (props) => {
const [isEditing, setIsEditing] = useBoolean();
const flaggable_labels = props.flaggable_labels;
const TEXT_LABEL_FLAGS = flaggable_labels.valid_labels.map((valid_label) => {
return {
attributeName: valid_label.name,
labelText: valid_label.display_text,
additionalExplanation: valid_label.help_text,
};
});
const { trigger } = useSWRMutation("/api/set_label", poster, {
onSuccess: () => {
setIsEditing.off;
@@ -181,48 +195,3 @@ export function FlagCheckbox(props: {
</Flex>
);
}
interface textFlagLabels {
attributeName: string;
labelText: string;
additionalExplanation?: string;
}
const TEXT_LABEL_FLAGS: textFlagLabels[] = [
// For the time being this list is configured on the FE.
// In the future it may be provided by the API.
// {
// attributeName: "fails_task",
// labelText: "Fails to follow the correct instruction / task",
// additionalExplanation: "__TODO__",
// },
// {
// attributeName: "not_customer_assistant_appropriate",
// labelText: "Inappropriate for customer assistant",
// additionalExplanation: "__TODO__",
// },
{
attributeName: "sexual_content",
labelText: "Contains sexual content",
},
{
attributeName: "violence",
labelText: "Contains violent content",
},
// {
// attributeName: "encourages_violence",
// labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
// },
// {
// attributeName: "denigrates_a_protected_class",
// labelText: "Denigrates a protected class",
// },
// {
// attributeName: "gives_harmful_advice",
// labelText: "Fails to follow the correct instruction / task",
// additionalExplanation:
// "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
// },
// {
// attributeName: "expresses_moral_judgement",
// labelText: "Expresses moral judgement",
// },
];
+22 -2
View File
@@ -10,11 +10,31 @@ export interface Message {
message_id: string;
}
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
export interface ValidLabel {
name: string;
display_text: string;
help_text: string;
}
export const Messages = ({
messages,
post_id,
valid_labels,
}: {
messages: Message[];
post_id: string;
valid_labels: ValidLabel[];
}) => {
const items = messages.map((messageProps: Message, i: number) => {
const { message_id, text } = messageProps;
return (
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
<FlaggableElement
text={text}
post_id={post_id}
message_id={message_id}
key={i + text}
flaggable_labels={valid_labels}
>
<MessageView {...messageProps} />
</FlaggableElement>
);
+4 -2
View File
@@ -18,7 +18,7 @@ export interface CreateTaskProps {
}
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
const task = tasks[0].task;
const valid_labels = tasks[0].valid_labels;
const [inputText, setInputText] = useState("");
const submitResponse = (task: { id: string }) => {
@@ -42,7 +42,9 @@ export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, m
<>
<h5 className="text-lg font-semibold">{taskType.label}</h5>
<p className="text-lg py-1">{taskType.overview}</p>
{task.conversation ? <Messages messages={task.conversation.messages} post_id={task.id} /> : null}
{task.conversation ? (
<Messages messages={task.conversation.messages} post_id={task.id} valid_labels={valid_labels} />
) : null}
</>
<>
<h5 className="text-lg font-semibold">{taskType.instruction}</h5>
+33
View File
@@ -48,6 +48,33 @@ export class OasstApiClient {
return await resp.json();
}
private async get(path: string): Promise<any> {
const resp = await fetch(`${this.oasstApiUrl}${path}`, {
method: "GET",
headers: {
"X-API-Key": this.oasstApiKey,
"Content-Type": "application/json",
},
});
if (resp.status == 204) {
return null;
}
if (resp.status >= 300) {
const errorText = await resp.text();
let error: any;
try {
error = JSON.parse(errorText);
} catch (e) {
throw new OasstError(errorText, 0, resp.status);
}
throw new OasstError(error.message ?? error, error.error_code, resp.status);
}
return await resp.json();
}
// TODO return a strongly typed Task?
// This method is used to store a task in RegisteredTask.task.
// This is a raw Json type, so we can't use it to strongly type the task.
@@ -96,6 +123,12 @@ export class OasstApiClient {
...content,
});
}
//Fetch valid labels. This is called every task. though the call may be redundant
//keeping this for future where the valid labels may change per task
async fetch_valid_text(): Promise<void> {
return this.get(`/api/v1/text_labels/valid_labels`);
}
}
export const oasstApiClient =
@@ -23,6 +23,7 @@ const handler = async (req, res) => {
// Fetch the new task.
const task = await oasstApiClient.fetchTask(task_type, token);
const valid_labels = await oasstApiClient.fetch_valid_text();
// Store the task and link it to the user..
const registeredTask = await prisma.registeredTask.create({
@@ -36,6 +37,11 @@ const handler = async (req, res) => {
},
});
// Add the valid labels that can be used to flag messages in this Task
registeredTask["valid_labels"] = valid_labels;
// Update the backend with our Task ID
await oasstApiClient.ackTask(task.id, registeredTask.id);
// Send the results to the client.
res.status(200).json(registeredTask);
};