@@ -1,14 +1,34 @@ | |||
from api import models | |||
from rest_framework import serializers | |||
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer | |||
class PasswordSerializer(serializers.ModelSerializer): | |||
class Meta: | |||
model = models.Password | |||
fields = ('id', 'login', 'site', 'lowercase', 'uppercase', 'symbols', 'numbers', 'counter', 'length', | |||
'version', 'created', 'modified') | |||
read_only_fields = ('created', 'modified') | |||
fields = ( | |||
"id", | |||
"login", | |||
"site", | |||
"lowercase", | |||
"uppercase", | |||
"symbols", | |||
"numbers", | |||
"counter", | |||
"length", | |||
"version", | |||
"created", | |||
"modified", | |||
) | |||
read_only_fields = ("created", "modified") | |||
def create(self, validated_data): | |||
user = self.context['request'].user | |||
user = self.context["request"].user | |||
return models.Password.objects.create(user=user, **validated_data) | |||
class BackwardCompatibleTokenObtainPairSerializer(TokenObtainPairSerializer): | |||
def validate(self, attrs): | |||
data = super().validate(attrs) | |||
data.update({"token": data["access"]}) | |||
return data |
@@ -0,0 +1,51 @@ | |||
from rest_framework.test import APITestCase, APIClient | |||
from api import models | |||
from api.tests import factories | |||
class OldRegisterTestCase(APITestCase): | |||
def test_register(self): | |||
self.assertEqual(0, models.LessPassUser.objects.all().count()) | |||
data = { | |||
"email": "contact@example.org", | |||
"password": "correct horse battery staple", | |||
} | |||
request = self.client.post("/api/auth/register/", data) | |||
self.assertEqual(request.status_code, 201) | |||
self.assertEqual(1, models.LessPassUser.objects.all().count()) | |||
def test_register_404_weak_password(self): | |||
self.assertEqual(0, models.LessPassUser.objects.all().count()) | |||
data = { | |||
"email": "contact@example.org", | |||
"password": "password", | |||
} | |||
request = self.client.post("/api/auth/register/", data) | |||
self.assertEqual(request.status_code, 400) | |||
self.assertEqual(0, models.LessPassUser.objects.all().count()) | |||
class OldLoginTestCase(APITestCase): | |||
def test_login(self): | |||
user = factories.UserFactory( | |||
email="contact@example.org", password="correct horse battery staple" | |||
) | |||
data = { | |||
"email": "contact@example.org", | |||
"password": "correct horse battery staple", | |||
} | |||
request = self.client.post("/api/tokens/auth/", data) | |||
self.assertEqual(request.status_code, 200) | |||
self.assertIsNotNone(request.data["token"]) | |||
def test_login_bad_password(self): | |||
user = factories.UserFactory( | |||
email="contact@example.org", password="correct horse battery staple" | |||
) | |||
data = { | |||
"email": "contact@example.org", | |||
"password": "not the good password", | |||
} | |||
request = self.client.post("/api/tokens/auth/", data) | |||
self.assertEqual(request.status_code, 401) |
@@ -4,13 +4,13 @@ from api import models | |||
from api.tests import factories | |||
class LogoutApiTestCase(APITestCase): | |||
class LogoutPasswordsTestCase(APITestCase): | |||
def test_get_passwords_401(self): | |||
response = self.client.get('/api/passwords/') | |||
self.assertEqual(401, response.status_code) | |||
class LoginApiTestCase(APITestCase): | |||
class LoginPasswordsTestCase(APITestCase): | |||
def setUp(self): | |||
self.user = factories.UserFactory() | |||
self.client = APIClient() | |||
@@ -1,4 +1,5 @@ | |||
import rest_framework_simplejwt.views | |||
import djoser.views | |||
from django.urls import include, path | |||
from rest_framework.routers import DefaultRouter | |||
@@ -6,11 +7,10 @@ from api import views | |||
router = DefaultRouter() | |||
router.register(r"passwords", views.PasswordViewSet, basename="passwords") | |||
router.register(r"auth/register", djoser.views.UserViewSet, basename="auth_register") | |||
urlpatterns = [ | |||
path("", include(router.urls)), | |||
path("tokens/auth/", rest_framework_simplejwt.views.token_obtain_pair), | |||
path("tokens/refresh/", rest_framework_simplejwt.views.token_refresh), | |||
path("tokens/auth/", views.BackwardCompatibleTokenObtainPairView.as_view()), | |||
path("auth/", include("djoser.urls")), | |||
path("auth/", include("djoser.urls.jwt")), | |||
] |
@@ -2,13 +2,26 @@ from api import models, serializers | |||
from api.permissions import IsOwner | |||
from rest_framework import permissions, viewsets | |||
from rest_framework_simplejwt.views import TokenObtainPairView | |||
class PasswordViewSet(viewsets.ModelViewSet): | |||
serializer_class = serializers.PasswordSerializer | |||
permission_classes = (permissions.IsAuthenticated, IsOwner,) | |||
search_fields = ('site', 'email',) | |||
ordering_fields = ('site', 'email', 'created') | |||
permission_classes = ( | |||
permissions.IsAuthenticated, | |||
IsOwner, | |||
) | |||
search_fields = ( | |||
"site", | |||
"email", | |||
) | |||
ordering_fields = ("site", "email", "created") | |||
def get_queryset(self): | |||
return models.Password.objects.filter(user=self.request.user) | |||
class BackwardCompatibleTokenObtainPairView(TokenObtainPairView): | |||
serializer_class = serializers.BackwardCompatibleTokenObtainPairSerializer | |||
token_obtain_pair = TokenObtainPairView.as_view() |