@@ -1,3 +1,4 @@ | |||||
import PIL.Image | |||||
import requests | import requests | ||||
from io import BytesIO | from io import BytesIO | ||||
@@ -20,6 +21,19 @@ class ImageManager(models.Manager): | |||||
'Chrome/48.0.2564.82 Safari/537.36', | 'Chrome/48.0.2564.82 Safari/537.36', | ||||
} | } | ||||
@staticmethod | |||||
def _is_valid_image(fp): | |||||
fp.seek(0) | |||||
try: | |||||
PIL.Image.open(fp) | |||||
except PIL.UnidentifiedImageError: | |||||
fp.seek(0) | |||||
return False | |||||
else: | |||||
fp.seek(0) | |||||
return True | |||||
# FIXME: Move this into an asynchronous task | # FIXME: Move this into an asynchronous task | ||||
def create_for_url(self, url, referer=None): | def create_for_url(self, url, referer=None): | ||||
file_name = url.split("/")[-1].split('#')[0].split('?')[0] | file_name = url.split("/")[-1].split('#')[0].split('?')[0] | ||||
@@ -29,6 +43,8 @@ class ImageManager(models.Manager): | |||||
headers["Referer"] = referer | headers["Referer"] = referer | ||||
response = requests.get(url, headers=headers) | response = requests.get(url, headers=headers) | ||||
buf.write(response.content) | buf.write(response.content) | ||||
if not self._is_valid_image(buf): | |||||
return None | |||||
obj = InMemoryUploadedFile(buf, 'image', file_name, | obj = InMemoryUploadedFile(buf, 'image', file_name, | ||||
None, buf.tell(), None) | None, buf.tell(), None) | ||||
# create the image and its thumbnails in one transaction, removing | # create the image and its thumbnails in one transaction, removing | ||||
@@ -130,6 +130,8 @@ class PinSerializer(serializers.HyperlinkedModelSerializer): | |||||
url, | url, | ||||
validated_data.get('referer', url), | validated_data.get('referer', url), | ||||
) | ) | ||||
if not image: | |||||
raise ValidationError({"url": "invalid image content"}) | |||||
else: | else: | ||||
image = validated_data.pop("image_by_id") | image = validated_data.pop("image_by_id") | ||||
tags = validated_data.pop('tag_list', []) | tags = validated_data.pop('tag_list', []) | ||||
@@ -23,6 +23,11 @@ def mock_requests_get(url, **kwargs): | |||||
return response | return response | ||||
def mock_requests_get_with_non_image_content(url, **kwargs): | |||||
response = mock.Mock(content=b"abcd") | |||||
return response | |||||
class ImageTests(APITestCase): | class ImageTests(APITestCase): | ||||
def test_post_create_unsupported(self): | def test_post_create_unsupported(self): | ||||
url = reverse("image-list") | url = reverse("image-list") | ||||
@@ -119,7 +124,7 @@ class PinPrivacyTests(APITestCase): | |||||
self.non_owner = create_user("non_owner") | self.non_owner = create_user("non_owner") | ||||
with mock.patch('requests.get', mock_requests_get): | with mock.patch('requests.get', mock_requests_get): | ||||
image = Image.objects.create_for_url('http://a.com/b.png') | |||||
image = create_image() | |||||
self.private_pin = Pin.objects.create( | self.private_pin = Pin.objects.create( | ||||
submitter=self.owner, | submitter=self.owner, | ||||
image=image, | image=image, | ||||
@@ -175,6 +180,20 @@ class PinTests(APITestCase): | |||||
def tearDown(self): | def tearDown(self): | ||||
_teardown_models() | _teardown_models() | ||||
@mock.patch('requests.get', mock_requests_get_with_non_image_content) | |||||
def test_should_not_create_pin_if_url_content_invalid(self): | |||||
url = 'http://testserver.com/mocked/logo-01.png' | |||||
create_url = reverse("pin-list") | |||||
referer = 'http://testserver.com/' | |||||
post_data = { | |||||
'url': url, | |||||
'private': False, | |||||
'referer': referer, | |||||
'description': 'That\'s an Apple!' | |||||
} | |||||
response = self.client.post(create_url, data=post_data, format="json") | |||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) | |||||
@mock.patch('requests.get', mock_requests_get) | @mock.patch('requests.get', mock_requests_get) | ||||
def test_should_create_pin(self): | def test_should_create_pin(self): | ||||
url = 'http://testserver.com/mocked/logo-01.png' | url = 'http://testserver.com/mocked/logo-01.png' | ||||