mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
90c3d5640e
* added db for inference * fixed dockerfiles for inference
42 lines
1.1 KiB
Python
42 lines
1.1 KiB
Python
import json
|
|
|
|
import pydantic.json
|
|
import sqlmodel
|
|
from loguru import logger
|
|
from oasst_inference_server import models
|
|
from oasst_inference_server.settings import settings
|
|
|
|
|
|
def default_json_serializer(obj):
    """Encode ``obj`` for ``json.dumps``, tagging dict results with its class name.

    Meant to be used as the ``default=`` hook of ``json.dumps``. The object is
    converted with ``pydantic.json.pydantic_encoder``; when that conversion
    yields a dict, the name of the object's class is stored under the
    ``"_classname_"`` key so the original type can be recognised when the
    JSON is loaded again.

    Args:
        obj: An object that ``json.dumps`` could not serialize natively.

    Returns:
        The value produced by ``pydantic_encoder``. Dict results additionally
        carry a ``"_classname_"`` entry; other results are returned as-is.

    Raises:
        Any exception raised by ``pydantic_encoder`` for objects it cannot
        encode is propagated unchanged.
    """
    encoded = pydantic.json.pydantic_encoder(obj)
    # pydantic_encoder does not always return a dict (e.g. it yields a string
    # for a datetime and a list for a set). Such values cannot take a string
    # key assignment, so only dict payloads receive the class-name tag.
    if isinstance(encoded, dict):
        encoded["_classname_"] = obj.__class__.__name__
    return encoded
|
|
|
|
|
|
def custom_json_serializer(obj):
    """Serialize ``obj`` to a JSON string.

    Objects the standard encoder cannot handle are delegated to
    ``default_json_serializer``.
    """
    # Same encoder configuration json.dumps builds when only ``default`` is set.
    encoder = json.JSONEncoder(default=default_json_serializer)
    return encoder.encode(obj)
|
|
|
|
|
|
def custom_json_deserializer(s):
    """Parse a JSON string, rebuilding tagged models where the tag is known.

    The string is decoded with ``json.loads``. Non-dict results, and dicts
    whose ``"_classname_"`` entry is missing or ``None``, are returned
    unchanged. Dicts tagged ``"Conversation"`` or ``"MessageRequest"`` are
    passed to the ``parse_obj`` method of the corresponding model class.

    Args:
        s: The JSON document to decode.

    Returns:
        The decoded value, or a model instance for a recognised tag.

    Raises:
        ValueError: If the dict carries a ``"_classname_"`` tag that is not
            one of the recognised names (the problem is also logged).
    """
    payload = json.loads(s)
    if not isinstance(payload, dict):
        return payload

    tag = payload.get("_classname_")
    if tag is None:
        # Untagged dicts are plain JSON objects; hand them back untouched.
        return payload
    if tag == "Conversation":
        return models.protocol.Conversation.parse_obj(payload)
    if tag == "MessageRequest":
        return models.interface.MessageRequest.parse_obj(payload)

    # A tag is present but does not correspond to any known model.
    message = f"Unknown class {payload['_classname_']}"
    logger.error(message)
    raise ValueError(message)
|
|
|
|
|
|
# Module-level engine for the configured database URI. The custom JSON hooks
# are handed to the engine so values it (de)serializes as JSON go through the
# class-name tagging implemented by the functions above.
db_engine = sqlmodel.create_engine(
    settings.database_uri,
    json_deserializer=custom_json_deserializer,
    json_serializer=custom_json_serializer,
)
|