fix steak interaction changes

This commit is contained in:
Andreas Köpf
2023-01-31 13:08:13 +01:00
parent 063157355c
commit 65d69dac66
3 changed files with 70 additions and 15 deletions
+8 -10
View File
@@ -8,7 +8,7 @@ from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_backend.tree_manager import TreeManager
from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_function
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -129,7 +129,6 @@ def tasks_acknowledge_failure(
@router.post("/interaction", response_model=protocol_schema.TaskDone)
async def tasks_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
interaction: protocol_schema.AnyInteraction,
) -> Any:
@@ -137,16 +136,15 @@ async def tasks_interaction(
The frontend reports an interaction.
"""
@managed_tx_function(CommitMode.COMMIT)
@async_managed_tx_function(CommitMode.COMMIT)
async def interaction_tx(session: deps.Session):
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, client_user=interaction.user)
tm = TreeManager(db, pr)
ur = UserRepository(db, api_client)
api_client = deps.api_auth(api_key, session)
pr = PromptRepository(session, api_client, client_user=interaction.user)
tm = TreeManager(session, pr)
ur = UserRepository(session, api_client)
task = await tm.handle_interaction(interaction)
match (type(task)):
case protocol_schema.TaskDone:
ur.update_user_last_activity(client_user=interaction.user)
if type(task) is protocol_schema.TaskDone:
ur.update_user_last_activity(user=pr.user)
return task
try:
+1 -5
View File
@@ -312,10 +312,6 @@ class UserRepository:
return qry.all()
@managed_tx_method(CommitMode.FLUSH)
def update_user_last_activity(self, client_user: protocol_schema.User) -> None:
user = self.lookup_client_user(client_user=client_user, create_missing=False)
if user is None:
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
def update_user_last_activity(self, user: User) -> None:
user.last_activity_date = utcnow()
self.db.add(user)
@@ -203,3 +203,64 @@ def managed_tx_function(
return wrapped_f
return decorator
def async_managed_tx_function(
auto_commit: CommitMode = CommitMode.COMMIT,
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
session_factory: Callable[..., Session] = default_session_factor,
):
"""Passes Session object as first argument to wrapped function."""
def decorator(f):
@wraps(f)
async def wrapped_f(*args, **kwargs):
try:
result = None
if auto_commit == CommitMode.COMMIT:
retry_exhausted = True
for i in range(num_retries):
with session_factory() as session:
try:
result = await f(session, *args, **kwargs)
session.commit()
if isinstance(result, SQLModel):
session.refresh(result)
retry_exhausted = False
break
except PendingRollbackError as e:
logger.info(str(e))
session.rollback()
except OperationalError as e:
if e.orig is not None and isinstance(
e.orig,
(SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation),
):
logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
session.rollback()
else:
raise e
logger.info(f"Retry {i+1}/{num_retries}")
if retry_exhausted:
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
else:
with session_factory() as session:
result = await f(session, *args, **kwargs)
if auto_commit == CommitMode.FLUSH:
session.flush()
if isinstance(result, SQLModel):
session.refresh(result)
elif auto_commit == CommitMode.ROLLBACK:
session.rollback()
return result
except Exception as e:
logger.info(str(e))
raise e
return wrapped_f
return decorator