diff --git a/flask_security/forms.py b/flask_security/forms.py index 44db7a7..8367549 100644 --- a/flask_security/forms.py +++ b/flask_security/forms.py @@ -244,8 +244,12 @@ class ConfirmRegisterForm(Form, RegisterFormMixin, pass -class RegisterForm(ConfirmRegisterForm, PasswordConfirmFormMixin): - pass +class RegisterForm(ConfirmRegisterForm, PasswordConfirmFormMixin, + NextFormMixin): + def __init__(self, *args, **kwargs): + super(RegisterForm, self).__init__(*args, **kwargs) + if not self.next.data: + self.next.data = request.args.get('next', '') class ResetPasswordForm(Form, NewPasswordFormMixin, PasswordConfirmFormMixin): diff --git a/flask_security/utils.py b/flask_security/utils.py index f519b3d..1fd7614 100644 --- a/flask_security/utils.py +++ b/flask_security/utils.py @@ -234,8 +234,8 @@ def get_post_login_redirect(declared=None): return get_post_action_redirect('SECURITY_POST_LOGIN_VIEW', declared) -def get_post_register_redirect(): - return get_post_action_redirect('SECURITY_POST_REGISTER_VIEW') +def get_post_register_redirect(declared=None): + return get_post_action_redirect('SECURITY_POST_REGISTER_VIEW', declared) def find_redirect(key): diff --git a/flask_security/views.py b/flask_security/views.py index be80372..4786af3 100644 --- a/flask_security/views.py +++ b/flask_security/views.py @@ -122,7 +122,7 @@ def register(): login_user(user) if not request.json: - return redirect(get_post_register_redirect()) + return redirect(get_post_register_redirect(form.next.data)) return _render_json(form, include_auth_token=True) if request.json: