diff --git a/flask_security/forms.py b/flask_security/forms.py
index 44db7a7..8367549 100644
--- a/flask_security/forms.py
+++ b/flask_security/forms.py
@@ -244,8 +244,12 @@ class ConfirmRegisterForm(Form, RegisterFormMixin,
pass
-class RegisterForm(ConfirmRegisterForm, PasswordConfirmFormMixin):
- pass
+class RegisterForm(ConfirmRegisterForm, PasswordConfirmFormMixin,
+ NextFormMixin):
+ def __init__(self, *args, **kwargs):
+ super(RegisterForm, self).__init__(*args, **kwargs)
+ if not self.next.data:
+ self.next.data = request.args.get('next', '')
class ResetPasswordForm(Form, NewPasswordFormMixin, PasswordConfirmFormMixin):
diff --git a/flask_security/templates/security/_menu.html b/flask_security/templates/security/_menu.html
index 5291f80..1917b72 100644
--- a/flask_security/templates/security/_menu.html
+++ b/flask_security/templates/security/_menu.html
@@ -1,9 +1,9 @@
{% if security.registerable or security.recoverable or security.confirmabled %}
Menu
- - Login
+ - Login
{% if security.registerable %}
- - Register
+ - Register
{% endif %}
{% if security.recoverable %}
- Forgot password
diff --git a/flask_security/utils.py b/flask_security/utils.py
index f519b3d..1fd7614 100644
--- a/flask_security/utils.py
+++ b/flask_security/utils.py
@@ -234,8 +234,8 @@ def get_post_login_redirect(declared=None):
return get_post_action_redirect('SECURITY_POST_LOGIN_VIEW', declared)
-def get_post_register_redirect():
- return get_post_action_redirect('SECURITY_POST_REGISTER_VIEW')
+def get_post_register_redirect(declared=None):
+ return get_post_action_redirect('SECURITY_POST_REGISTER_VIEW', declared)
def find_redirect(key):
diff --git a/flask_security/views.py b/flask_security/views.py
index be80372..864c9e3 100644
--- a/flask_security/views.py
+++ b/flask_security/views.py
@@ -122,7 +122,12 @@ def register():
login_user(user)
if not request.json:
- return redirect(get_post_register_redirect())
+ if 'next' in form:
+ redirect_url = get_post_register_redirect(form.next.data)
+ else:
+ redirect_url = get_post_register_redirect()
+
+ return redirect(redirect_url)
return _render_json(form, include_auth_token=True)
if request.json:
diff --git a/tests/test_confirmable.py b/tests/test_confirmable.py
index f1c0b9e..8fddfdc 100644
--- a/tests/test_confirmable.py
+++ b/tests/test_confirmable.py
@@ -35,7 +35,8 @@ def test_confirmable_flag(app, client, sqlalchemy_datastore, get_message):
email = 'dude@lp.com'
with capture_registrations() as registrations:
- response = client.post('/register', data=dict(email=email, password='password'))
+ data = dict(email=email, password='password', next='')
+ response = client.post('/register', data=data)
assert response.status_code == 302
@@ -85,7 +86,8 @@ def test_confirmable_flag(app, client, sqlalchemy_datastore, get_message):
# Test user was deleted before confirmation
with capture_registrations() as registrations:
- client.post('/register', data=dict(email='mary@lp.com', password='password'))
+ data = dict(email='mary@lp.com', password='password', next='')
+ client.post('/register', data=data)
user = registrations[0]['user']
token = registrations[0]['confirm_token']
@@ -102,7 +104,7 @@ def test_confirmable_flag(app, client, sqlalchemy_datastore, get_message):
@pytest.mark.settings(confirm_email_within='1 milliseconds')
def test_expired_confirmation_token(client, get_message):
with capture_registrations() as registrations:
- data = dict(email='mary@lp.com', password='password')
+ data = dict(email='mary@lp.com', password='password', next='')
client.post('/register', data=data, follow_redirects=True)
user = registrations[0]['user']
@@ -118,7 +120,7 @@ def test_expired_confirmation_token(client, get_message):
@pytest.mark.registerable()
@pytest.mark.settings(login_without_confirmation=True)
def test_login_when_unconfirmed(client, get_message):
- data = dict(email='mary@lp.com', password='password')
+ data = dict(email='mary@lp.com', password='password', next='')
response = client.post('/register', data=data, follow_redirects=True)
assert b'mary@lp.com' in response.data
@@ -131,7 +133,8 @@ def test_confirmation_different_user_when_logged_in(client, get_message):
with capture_registrations() as registrations:
for e in e1, e2:
- client.post('/register', data=dict(email=e, password='password'))
+ data = dict(email=e, password='password', next='')
+ client.post('/register', data=data)
logout(client)
token1 = registrations[0]['confirm_token']
diff --git a/tests/test_registerable.py b/tests/test_registerable.py
index 542e1e4..02b5c5f 100644
--- a/tests/test_registerable.py
+++ b/tests/test_registerable.py
@@ -28,7 +28,10 @@ def test_registerable_flag(client, app, get_message):
def on_user_registerd(app, user, confirm_token):
recorded.append(user)
- data = dict(email='dude@lp.com', password='password', password_confirm='password')
+ data = dict(
+ email='dude@lp.com', password='password', password_confirm='password',
+ next=''
+ )
with app.mail.record_messages() as outbox:
response = client.post('/register', data=data, follow_redirects=True)
@@ -45,7 +48,10 @@ def test_registerable_flag(client, app, get_message):
logout(client)
# Test registering with an existing email
- data = dict(email='dude@lp.com', password='password', password_confirm='password')
+ data = dict(
+ email='dude@lp.com', password='password', password_confirm='password',
+ next=''
+ )
response = client.post('/register', data=data, follow_redirects=True)
assert get_message('EMAIL_ALREADY_ASSOCIATED', email='dude@lp.com') in response.data
@@ -68,7 +74,8 @@ def test_registerable_flag(client, app, get_message):
# Test ?next param
data = dict(email='dude3@lp.com',
password='password',
- password_confirm='password')
+ password_confirm='password',
+ next='')
response = client.post('/register?next=/page1', data=data, follow_redirects=True)
assert b'Page 1' in response.data
@@ -81,7 +88,8 @@ def test_custom_register_url(client):
data = dict(email='dude@lp.com',
password='password',
- password_confirm='password')
+ password_confirm='password',
+ next='')
response = client.post('/custom_register', data=data, follow_redirects=True)
assert b'Post Register' in response.data
@@ -95,7 +103,10 @@ def test_custom_register_tempalate(client):
@pytest.mark.settings(send_register_email=False)
def test_disable_register_emails(client, app):
- data = dict(email='dude@lp.com', password='password', password_confirm='password')
+ data = dict(
+ email='dude@lp.com', password='password', password_confirm='password',
+ next=''
+ )
with app.mail.record_messages() as outbox:
client.post('/register', data=data, follow_redirects=True)
assert len(outbox) == 0