diff --git a/flask_security/__init__.py b/flask_security/__init__.py index dfce324..3464c24 100644 --- a/flask_security/__init__.py +++ b/flask_security/__init__.py @@ -10,6 +10,8 @@ :license: MIT, see LICENSE for more details. """ +from flask.ext.login import login_user, logout_user + from .core import Security, RoleMixin, UserMixin, AnonymousUser, \ AuthenticationProvider, current_user from .decorators import auth_token_required, http_auth_required, \ diff --git a/flask_security/forms.py b/flask_security/forms.py index 46cc452..1b780c7 100644 --- a/flask_security/forms.py +++ b/flask_security/forms.py @@ -14,51 +14,73 @@ from flask.ext.wtf import Form, TextField, PasswordField, SubmitField, \ HiddenField, Required, BooleanField, EqualTo, Email -class ForgotPasswordForm(Form): +class EmailFormMixin(): email = TextField("Email Address", - validators=[Required(message="Email not provided")]) + validators=[Required(message="Email not provided"), + Email(message="Invalid email address")]) + + +class PasswordFormMixin(): + password = PasswordField("Password", + validators=[Required(message="Password not provided")]) + + +class PasswordConfirmFormMixin(): + password_confirm = PasswordField("Retype Password", + validators=[EqualTo('password', message="Passwords do not match")]) + + +class ForgotPasswordForm(Form, EmailFormMixin): + """The default forgot password form""" + + submit = SubmitField("Recover Password") def to_dict(self): return dict(email=self.email.data) -class LoginForm(Form): +class LoginForm(Form, EmailFormMixin, PasswordFormMixin): """The default login form""" - email = TextField("Email Address", - validators=[Required(message="Email not provided")]) - password = PasswordField("Password", - validators=[Required(message="Password not provided")]) remember = BooleanField("Remember Me") next = HiddenField() submit = SubmitField("Login") def __init__(self, *args, **kwargs): super(LoginForm, self).__init__(*args, **kwargs) - self.next.data = request.args.get('next', None) + + if request.method == 'GET': + self.next.data = request.args.get('next', None) -class RegisterForm(Form): +class RegisterForm(Form, + EmailFormMixin, + PasswordFormMixin, + PasswordConfirmFormMixin): """The default register form""" - email = TextField("Email Address", - validators=[Required(message='Email not provided'), Email()]) - password = PasswordField("Password", - validators=[Required(message="Password not provided")]) - password_confirm = PasswordField("Retype Password", - validators=[EqualTo('password', message="Passwords do not match")]) + submit = SubmitField("Register") def to_dict(self): return dict(email=self.email.data, password=self.password.data) -class ResetPasswordForm(Form): +class ResetPasswordForm(Form, + EmailFormMixin, + PasswordFormMixin, + PasswordConfirmFormMixin): + """The default reset password form""" + token = HiddenField(validators=[Required()]) - email = HiddenField(validators=[Required()]) - password = PasswordField("Password", - validators=[Required(message="Password not provided")]) - password_confirm = PasswordField("Retype Password", - validators=[EqualTo('password', message="Passwords do not match")]) + + submit = SubmitField("Reset Password") + + def __init__(self, *args, **kwargs): + super(ResetPasswordForm, self).__init__(*args, **kwargs) + + if request.method == 'GET': + self.token.data = request.args.get('token', None) + self.email.data = request.args.get('email', None) def to_dict(self): return dict(token=self.token.data, diff --git a/flask_security/recoverable.py b/flask_security/recoverable.py index 9a2e4cd..a53144a 100644 --- a/flask_security/recoverable.py +++ b/flask_security/recoverable.py @@ -35,6 +35,7 @@ def find_user_by_reset_token(token): def send_reset_password_instructions(user): url = url_for('flask_security.reset', + email=user.email, reset_token=user.reset_password_token) reset_link = request.url_root[:-1] + url diff --git a/tests/functional_tests.py b/tests/functional_tests.py index 3a728d6..0905faf 100644 --- a/tests/functional_tests.py +++ b/tests/functional_tests.py @@ -27,8 +27,12 @@ class DefaultSecurityTests(SecurityTest): r = self.authenticate(password="") self.assertIn("Password not provided", r.data) - def test_invalid_user(self): + def test_invalid_email(self): r = self.authenticate(email="bogus") + self.assertIn("Invalid email address", r.data) + + def test_invalid_user(self): + r = self.authenticate(email="bogus@bogus.com") self.assertIn("Specified user does not exist", r.data) def test_bad_password(self):