mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
fix steak interaction changes
This commit is contained in:
@@ -8,7 +8,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -129,7 +129,6 @@ def tasks_acknowledge_failure(
|
||||
@router.post("/interaction", response_model=protocol_schema.TaskDone)
|
||||
async def tasks_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
interaction: protocol_schema.AnyInteraction,
|
||||
) -> Any:
|
||||
@@ -137,16 +136,15 @@ async def tasks_interaction(
|
||||
The frontend reports an interaction.
|
||||
"""
|
||||
|
||||
@managed_tx_function(CommitMode.COMMIT)
|
||||
@async_managed_tx_function(CommitMode.COMMIT)
|
||||
async def interaction_tx(session: deps.Session):
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, client_user=interaction.user)
|
||||
tm = TreeManager(db, pr)
|
||||
ur = UserRepository(db, api_client)
|
||||
api_client = deps.api_auth(api_key, session)
|
||||
pr = PromptRepository(session, api_client, client_user=interaction.user)
|
||||
tm = TreeManager(session, pr)
|
||||
ur = UserRepository(session, api_client)
|
||||
task = await tm.handle_interaction(interaction)
|
||||
match (type(task)):
|
||||
case protocol_schema.TaskDone:
|
||||
ur.update_user_last_activity(client_user=interaction.user)
|
||||
if type(task) is protocol_schema.TaskDone:
|
||||
ur.update_user_last_activity(user=pr.user)
|
||||
return task
|
||||
|
||||
try:
|
||||
|
||||
@@ -312,10 +312,6 @@ class UserRepository:
|
||||
return qry.all()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def update_user_last_activity(self, client_user: protocol_schema.User) -> None:
|
||||
user = self.lookup_client_user(client_user=client_user, create_missing=False)
|
||||
if user is None:
|
||||
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
|
||||
def update_user_last_activity(self, user: User) -> None:
|
||||
user.last_activity_date = utcnow()
|
||||
self.db.add(user)
|
||||
|
||||
@@ -203,3 +203,64 @@ def managed_tx_function(
|
||||
return wrapped_f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def async_managed_tx_function(
|
||||
auto_commit: CommitMode = CommitMode.COMMIT,
|
||||
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
|
||||
session_factory: Callable[..., Session] = default_session_factor,
|
||||
):
|
||||
"""Passes Session object as first argument to wrapped function."""
|
||||
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
async def wrapped_f(*args, **kwargs):
|
||||
try:
|
||||
result = None
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
retry_exhausted = True
|
||||
for i in range(num_retries):
|
||||
with session_factory() as session:
|
||||
try:
|
||||
result = await f(session, *args, **kwargs)
|
||||
session.commit()
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
retry_exhausted = False
|
||||
break
|
||||
except PendingRollbackError as e:
|
||||
logger.info(str(e))
|
||||
session.rollback()
|
||||
except OperationalError as e:
|
||||
if e.orig is not None and isinstance(
|
||||
e.orig,
|
||||
(SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation),
|
||||
):
|
||||
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
|
||||
session.rollback()
|
||||
else:
|
||||
raise e
|
||||
logger.info(f"Retry {i+1}/{num_retries}")
|
||||
if retry_exhausted:
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
with session_factory() as session:
|
||||
result = await f(session, *args, **kwargs)
|
||||
if auto_commit == CommitMode.FLUSH:
|
||||
session.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
session.rollback()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.info(str(e))
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
|
||||
return decorator
|
||||
|
||||
Reference in New Issue
Block a user