diff --git a/flask_security/views.py b/flask_security/views.py index 33e6ab9..0deb68c 100644 --- a/flask_security/views.py +++ b/flask_security/views.py @@ -96,22 +96,33 @@ def logout(): def register(): """View function which handles a registration request.""" - if _security.confirmable: - form = ConfirmRegisterForm() + if _security.confirmable or request.json: + form_class = ConfirmRegisterForm else: - form = RegisterForm() + form_class = RegisterForm + + if request.json: + form_data = MultiDict(request.json) + else: + form_data = None + + form = form_class(form_data) if form.validate_on_submit(): user = register_user(**form.to_dict()) + form.user = user if not _security.confirmable or _security.login_without_confirmation: after_this_request(_commit) login_user(user) - post_register_url = get_url(_security.post_register_view) - post_login_url = get_url(_security.post_login_view) + if not request.json: + post_register_url = get_url(_security.post_register_view) + post_login_url = get_url(_security.post_login_view) + return redirect(post_register_url or post_login_url) - return redirect(post_register_url or post_login_url) + if request.json: + return _render_json(form) return render_template('security/register_user.html', register_user_form=form, @@ -174,7 +185,7 @@ def confirm_email(token): """View function which handles a email confirmation request.""" expired, invalid, user = confirm_email_token_status(token) - print expired, invalid, user + if invalid: do_flash(*get_message('INVALID_CONFIRMATION_TOKEN')) if expired: diff --git a/tests/functional_tests.py b/tests/functional_tests.py index eb0c5c4..1595825 100644 --- a/tests/functional_tests.py +++ b/tests/functional_tests.py @@ -117,7 +117,9 @@ class DefaultSecurityTests(SecurityTest): def test_ok_json_auth(self): r = self.json_authenticate() - self.assertIn('"code": 200', r.data) + data = json.loads(r.data) + self.assertEquals(data['meta']['code'], 200) + self.assertIn('authentication_token', data['response']['user']) def test_invalid_json_auth(self): r = self.json_authenticate(password='junk') @@ -250,6 +252,15 @@ class ConfiguredSecurityTests(SecurityTest): r = self._post('/register', data=data, follow_redirects=True) self.assertIn('Post Register', r.data) + def test_register_json(self): + r = self._post('/register', + data='{ "email": "dude@lp.com", "password": "password" }', + content_type='application/json') + data = json.loads(r.data) + print data + self.assertEquals(data['meta']['code'], 200) + self.assertIn('authentication_token', data['response']['user']) + def test_register_existing_email(self): data = dict(email='matt@lp.com', password='password',