mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-29 16:30:24 +08:00
ead51ff423
* only ranking assistant replies by default * add tasks/availability endpoint, allow specifying the desired task * move rank_prompter_replies option to TreeManagerConfiguration * fix type annotation * remove desired_task_type from _random_task_selection() * fix typo * Convert query_tree_size to SQLAlchemy, return 'full' text-labeling tasks if they were explicitly requested
156 lines
5.0 KiB
Python
156 lines
5.0 KiB
Python
from typing import Any, Optional
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter, Depends
|
|
from fastapi.security.api_key import APIKey
|
|
from loguru import logger
|
|
from oasst_backend.api import deps
|
|
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
|
|
from oasst_backend.tree_manager import TreeManager
|
|
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
|
from oasst_shared.schemas import protocol as protocol_schema
|
|
from sqlmodel import Session
|
|
from starlette.status import HTTP_204_NO_CONTENT
|
|
|
|
# Router that collects the task endpoints defined below (request, availability,
# ack/nack, interaction, close). NOTE(review): the URL prefix it is mounted
# under is configured elsewhere — confirm in the application's router setup.
router = APIRouter()
|
|
|
|
|
|
@router.post(
    "/",
    response_model=protocol_schema.AnyTask,
    dependencies=[
        Depends(deps.UserRateLimiter(times=100, minutes=5)),
        Depends(deps.APIClientRateLimiter(times=10_000, minutes=1)),
    ],
)
def request_task(
    *,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
    request: protocol_schema.TaskRequest,
) -> Any:
    """
    Create new task.

    Authenticates the calling API client, asks the ``TreeManager`` for the next
    task of ``request.type``, persists it together with its message-tree id,
    parent message id and the request's ``collective`` flag, and returns it.

    Raises:
        OasstError: re-raised unchanged when raised while generating or storing
            the task; any other exception is logged and wrapped with
            ``OasstErrorCode.TASK_GENERATION_FAILED``.
    """
    client = deps.api_auth(api_key, db)

    try:
        prompt_repo = PromptRepository(db, client, client_user=request.user)
        new_task, tree_id, parent_id = TreeManager(db, prompt_repo).next_task(request.type)
        # Store the generated task so that later ack/nack/interaction calls can find it.
        prompt_repo.task_repository.store_task(new_task, tree_id, parent_id, request.collective)
    except OasstError:
        # Domain errors already carry a meaningful error code; pass them through.
        raise
    except Exception:
        logger.exception("Failed to generate task..")
        raise OasstError("Failed to generate task.", OasstErrorCode.TASK_GENERATION_FAILED)

    return new_task
|
|
|
|
|
|
@router.post("/availability", response_model=dict[protocol_schema.TaskRequestType, int])
def tasks_availability(
    *,
    user: Optional[protocol_schema.User] = None,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
):
    """
    Report how many tasks of each request type are currently available.

    ``user`` is optional; when supplied it is handed to the
    ``PromptRepository`` as the client user. The mapping returned is whatever
    ``TreeManager.determine_task_availability()`` produces.

    Raises:
        OasstError: re-raised unchanged when raised by the query; any other
            exception is logged and wrapped with
            ``OasstErrorCode.TASK_AVAILABILITY_QUERY_FAILED``.
    """
    client = deps.api_auth(api_key, db)

    try:
        manager = TreeManager(db, PromptRepository(db, client, client_user=user))
        availability = manager.determine_task_availability()
    except OasstError:
        raise
    except Exception:
        logger.exception("Task availability query failed.")
        raise OasstError("Task availability query failed.", OasstErrorCode.TASK_AVAILABILITY_QUERY_FAILED)
    else:
        return availability
|
|
|
|
|
|
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge(
    *,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
    task_id: UUID,
    ack_request: protocol_schema.TaskAck,
) -> None:
    """
    The frontend acknowledges a task.

    Binds ``ack_request.message_id`` (the frontend's id for the message) to the
    task identified by ``task_id``. Responds with 204 No Content.

    Raises:
        OasstError: re-raised unchanged when raised while binding; any other
            exception is logged and wrapped with
            ``OasstErrorCode.TASK_ACK_FAILED``.
    """
    client = deps.api_auth(api_key, db)

    try:
        prompt_repo = PromptRepository(db, client)
        logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
        # Persist the frontend message id on the stored task record.
        prompt_repo.task_repository.bind_frontend_message_id(
            task_id=task_id,
            frontend_message_id=ack_request.message_id,
        )
    except OasstError:
        raise
    except Exception:
        logger.exception("Failed to acknowledge task.")
        raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
|
|
|
|
|
|
@router.post("/{task_id}/nack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge_failure(
    *,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
    task_id: UUID,
    nack_request: protocol_schema.TaskNAck,
) -> None:
    """
    The frontend reports failure to implement a task.

    Logs the report (including ``nack_request``) and marks the task identified
    by ``task_id`` as failed via ``acknowledge_task_failure``. Responds with
    204 No Content.

    Raises:
        OasstError: re-raised unchanged when raised by authentication or by the
            repository; any other exception while recording the failure is
            logged and wrapped with ``OasstErrorCode.TASK_NACK_FAILED``.
    """
    logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
    # Authenticate outside the try block (as the sibling endpoints do) so that
    # authentication problems are not misreported as nack failures.
    api_client = deps.api_auth(api_key, db)

    try:
        pr = PromptRepository(db, api_client)
        pr.task_repository.acknowledge_task_failure(task_id)
    except OasstError:
        # Domain errors already carry a meaningful error code; pass them through.
        raise
    except Exception:
        # Previously only KeyError/RuntimeError were wrapped; any other failure
        # (e.g. a database error) escaped without TASK_NACK_FAILED.
        logger.exception("Failed to not acknowledge task.")
        raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
|
|
|
|
|
@router.post("/interaction", response_model=protocol_schema.TaskDone)
async def tasks_interaction(
    *,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
    interaction: protocol_schema.AnyInteraction,
) -> Any:
    """
    The frontend reports an interaction.

    Builds a ``TreeManager`` scoped to ``interaction.user`` and returns the
    awaited result of ``handle_interaction(interaction)``.

    Raises:
        OasstError: re-raised unchanged when raised while handling the
            interaction; any other exception is logged and wrapped with
            ``OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED``.
    """
    client = deps.api_auth(api_key, db)

    try:
        manager = TreeManager(db, PromptRepository(db, client, client_user=interaction.user))
        outcome = await manager.handle_interaction(interaction)
    except OasstError:
        raise
    except Exception:
        logger.exception("Interaction request failed.")
        raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)

    return outcome
|
|
|
|
|
|
@router.post("/close", response_model=protocol_schema.TaskDone)
def close_collective_task(
    close_task_request: protocol_schema.TaskClose,
    db: Session = Depends(deps.get_db),
    api_key: APIKey = Depends(deps.get_api_key),
):
    """
    Close a collective task.

    Authenticates the API client, then calls ``TaskRepository.close_task`` with
    ``close_task_request.message_id`` and returns an empty ``TaskDone``.

    NOTE(review): unlike the other endpoints in this router, exceptions raised
    by ``close_task`` are not logged or wrapped in ``OasstError`` here.
    """
    task_repo = TaskRepository(db, deps.api_auth(api_key, db))
    task_repo.close_task(close_task_request.message_id)
    return protocol_schema.TaskDone()
|