您最多可以选择 25 个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 
 
 

93 行
2.8 KiB

import io

from django.conf import settings
from django.contrib.auth.models import Permission
from django.core.files.images import ImageFile
from django.db.models.query import QuerySet
from django.test import TestCase
from django_images.models import Thumbnail
import factory
from taggit.models import Tag

from core.models import Pin, Image
from users.models import User
  11. TEST_IMAGE_PATH = 'logo.png'
  12. class UserFactory(factory.Factory):
  13. FACTORY_FOR = User
  14. username = factory.Sequence(lambda n: 'user_{}'.format(n))
  15. email = factory.Sequence(lambda n: 'user_{}@example.com'.format(n))
  16. @factory.post_generation(extract_prefix='password')
  17. def set_password(self, create, extracted, **kwargs):
  18. self.set_password(extracted)
  19. self.save()
  20. @factory.post_generation(extract_prefix='user_permissions')
  21. def set_user_permissions(self, create, extracted, **kwargs):
  22. self.user_permissions = Permission.objects.filter(codename__in=['add_pin', 'add_image'])
  23. class TagFactory(factory.Factory):
  24. FACTORY_FOR = Tag
  25. name = factory.Sequence(lambda n: 'tag_{}'.format(n))
  26. class ImageFactory(factory.Factory):
  27. FACTORY_FOR = Image
  28. image = factory.LazyAttribute(lambda a: ImageFile(open(TEST_IMAGE_PATH, 'rb')))
  29. @factory.post_generation()
  30. def create_thumbnails(self, create, extracted, **kwargs):
  31. for size in settings.IMAGE_SIZES.keys():
  32. Thumbnail.objects.get_or_create_at_size(self.pk, size)
  33. class PinFactory(factory.Factory):
  34. FACTORY_FOR = Pin
  35. submitter = factory.SubFactory(UserFactory)
  36. image = factory.SubFactory(ImageFactory)
  37. @factory.post_generation(extract_prefix='tags')
  38. def add_tags(self, create, extracted, **kwargs):
  39. if isinstance(extracted, Tag):
  40. self.tags.add(extracted)
  41. elif isinstance(extracted, list):
  42. self.tags.add(*extracted)
  43. elif isinstance(extracted, QuerySet):
  44. self.tags = extracted
  45. else:
  46. self.tags.add(TagFactory())
  47. class PinFactoryTest(TestCase):
  48. def test_default_tags(self):
  49. tags = PinFactory.create().tags.all()
  50. self.assertTrue(all([tag.name.startswith('tag_') for tag in tags]))
  51. self.assertEqual(tags.count(), 1)
  52. def test_custom_tag(self):
  53. custom = 'custom_tag'
  54. self.assertEqual(PinFactory(tags=Tag.objects.create(name=custom)).tags.get(pk=1).name, custom)
  55. def test_custom_tags_list(self):
  56. tags = TagFactory.create_batch(2)
  57. PinFactory(tags=tags)
  58. self.assertEqual(Tag.objects.count(), 2)
  59. def test_custom_tags_queryset(self):
  60. TagFactory.create_batch(2)
  61. tags = Tag.objects.all()
  62. PinFactory(tags=tags)
  63. self.assertEqual(Tag.objects.count(), 2)
  64. def test_empty_tags(self):
  65. PinFactory(tags=[])
  66. self.assertEqual(Tag.objects.count(), 0)