-
Notifications
You must be signed in to change notification settings - Fork 516
/
video.py
1406 lines (1233 loc) · 52.5 KB
/
video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
import datetime
import hashlib
import math
import os
import random
import subprocess
import tempfile
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import psutil
import torch
import torchaudio
from torch.nn import functional as F
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as FV
from corenet.data.transforms import TRANSFORMATIONS_REGISTRY, BaseTransformation
from corenet.data.transforms.utils import *
from corenet.options.parse_args import JsonValidator
from corenet.utils import logger
# Interpolation modes accepted by `_check_interpolation`; also used as the argparse
# `choices` of the resize/crop transforms in this file.
SUPPORTED_PYTORCH_INTERPOLATIONS = ["nearest", "bilinear", "bicubic"]


def _check_interpolation(interpolation):
    """Validate an interpolation mode name.

    Args:
        interpolation: Name of the interpolation mode (later passed as ``mode`` to
            ``torch.nn.functional.interpolate`` by ``_resize_fn``).

    Returns:
        ``interpolation`` unchanged.

    If ``interpolation`` is not in ``SUPPORTED_PYTORCH_INTERPOLATIONS``,
    ``logger.error`` is called with the rejected value and the supported modes.
    """
    if interpolation not in SUPPORTED_PYTORCH_INTERPOLATIONS:
        supported = "".join(
            "\n\t{}: {}".format(idx, mode)
            for idx, mode in enumerate(SUPPORTED_PYTORCH_INTERPOLATIONS)
        )
        # Report the offending value so misconfigurations are easy to track down.
        logger.error(
            "Unsupported interpolation mode: {}. Supported interpolation modes"
            " are:{}".format(interpolation, supported)
        )
    return interpolation
def _crop_fn(data: Dict, i: int, j: int, h: int, w: int) -> Dict:
    """Crop the video (and optional mask) in `data` to a rectangle.

    Args:
        data: A dictionary of data. The format is:
            {
                "samples":{
                    "video": A video tensor of shape [...x H x W], where H and W are the
                        height and width.
                    "mask": An optional mask tensor of shape [... x H x W].
                    "audio": An audio tensor.
                }
                "mask": An optional mask tensor of shape [... x H x W], only used
                    when `data["samples"]` has no mask.
            }
        i: The height coordinate of the top left corner of the cropped rectangle.
        j: The width coordinate of the top left corner of the cropped rectangle.
        h: The height of the cropped rectangle.
        w: The width of the cropped rectangle.

    Returns:
        `data`, modified in place, where `data["samples"]["video"]` is the cropped
        video and, if a mask was found, `data["samples"]["mask"]` is the mask
        cropped to the same rectangle.
    """
    samples = data["samples"]
    video = samples["video"]
    check_rgb_video_tensor(video)
    samples["video"] = video[..., i : i + h, j : j + w]

    # `_resize_fn` reads/writes masks in `data["samples"]["mask"]`, so that is the
    # primary location. A top-level `data["mask"]` is still honoured (and, as
    # before, its crop is written to `data["samples"]["mask"]`).
    mask = samples.get("mask", None)
    if mask is None:
        mask = data.get("mask", None)
    if mask is not None:
        samples["mask"] = mask[..., i : i + h, j : j + w]
    return data
def _resize_fn(
    data: Dict,
    size: Union[Sequence, int],
    interpolation: Optional[str] = "bilinear",
) -> Dict:
    """Resize the video (and optional mask) in `data`.

    Args:
        data: A dictionary of data. The format is:
            {
                "samples":{
                    "video": A 5-D video tensor of shape [N x T x C x H x W] or
                        [N x C x T x H x W], where H and W are the height and width.
                    "mask": An optional entry of the mask tensor of shape
                        [... x H x W], where H and W are the height and width.
                    "audio": An audio tensor.
                }
            }
        size: The size to resize to. A pair ``(height, width)`` resizes to exactly
            that size. An int (or a one-element sequence holding an int) resizes the
            shorter side to ``size`` and scales the longer side to keep the aspect
            ratio; if the shorter side already equals ``size``, `data` is returned
            unchanged.
        interpolation: The interpolation mode for the video. Must be one of
            ``SUPPORTED_PYTORCH_INTERPOLATIONS`` ("nearest", "bilinear", "bicubic").
            Defaults to "bilinear". Masks are always resized with "nearest".

    Returns:
        `data`, modified in place, where `data["samples"]["video"]` is the resized
        video and `data["samples"]["mask"]` (if present) is the resized mask.

    Raises:
        TypeError: If `size` is neither an int nor a sequence of length 1 or 2.
    """
    video = data["samples"]["video"]

    # A single-value sequence (e.g. argparse `nargs="+"` given one value) has the
    # same meaning as a plain int.
    if isinstance(size, Sequence) and len(size) == 1 and isinstance(size[0], int):
        size = size[0]

    if isinstance(size, Sequence) and len(size) == 2:
        size_h, size_w = size[0], size[1]
    elif isinstance(size, int):
        h, w = video.shape[-2:]
        # The shorter side already has the requested size: nothing to do.
        if (w <= h and w == size) or (h <= w and h == size):
            return data
        if w < h:
            size_h = int(size * h / w)
            size_w = size
        else:
            size_w = int(size * w / h)
            size_h = size
    else:
        raise TypeError(
            "Supported size args are int or a sequence of length 1 or 2. Got"
            " inappropriate size arg: {}".format(size)
        )

    if isinstance(interpolation, str):
        interpolation = _check_interpolation(interpolation)

    n, tc1, tc2, h, w = video.shape
    # Since video could be either NTCHW or NCTHW format, fold the two middle axes
    # into one so F.interpolate sees a 4-D input, then unfold back to 5-D.
    video = F.interpolate(
        input=video.reshape(n, tc1 * tc2, h, w),
        size=(size_h, size_w),
        mode=interpolation,
        align_corners=True if interpolation != "nearest" else None,
    )
    data["samples"]["video"] = video.reshape(n, tc1, tc2, size_h, size_w)

    mask = data["samples"].get("mask", None)
    if mask is not None:
        # Nearest-neighbour resizing treats each channel independently, so all
        # leading dims can be folded into the channel axis. This resizes masks of
        # any rank >= 2 and gives the same result as before for 4-D masks.
        leading_shape = mask.shape[:-2]
        mask_h, mask_w = mask.shape[-2:]
        mask = F.interpolate(
            input=mask.reshape(1, -1, mask_h, mask_w),
            size=(size_h, size_w),
            mode="nearest",
        )
        data["samples"]["mask"] = mask.reshape(*leading_shape, size_h, size_w)
    return data
def check_rgb_video_tensor(clip: torch.Tensor) -> None:
    """Check if the video tensor is the right type and shape.

    Args:
        clip: A video clip tensor of shape [N x C x T x H x W] or
            [N x T x C x H x W], where N is the number of clips, T is the number
            of frames of the clip, C is the number of image channels,
            H and W are the height and width of the frame image.

    Every failed check is reported through ``logger.error``: the clip must be a
    float32 ``torch.Tensor`` (on any device) with exactly 5 dimensions.
    """
    # Compare the dtype instead of `isinstance(clip, torch.FloatTensor)`: that
    # legacy type check only matches CPU float32 tensors.
    if not isinstance(clip, torch.Tensor) or clip.dtype != torch.float32:
        logger.error("Video clip is not an instance of FloatTensor.")
    if clip.dim() != 5:
        logger.error("Video clip is not a 5-d tensor (NTCHW or NCTHW).")
@TRANSFORMATIONS_REGISTRY.register(name="to_tensor", type="video")
class ToTensor(BaseTransformation):
    """
    This method converts a video clip into a tensor.

    Tensor shape abbreviations:
        N: Number of clips.
        T, T_audio, T_video: Temporal lengths.
        C: Number of color channels.
        H, W: Height, Width.

    .. note::
        We do not perform any mean-std normalization. If mean-std normalization is
        desired, please modify this class.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts=opts)

    def __call__(self, data: Dict) -> Dict:
        """Convert ``data["samples"]["video"]`` to a tensor, in place.

        Non-tensor inputs are wrapped with ``torch.from_numpy``. Clips that are not
        already float32 are divided by 255 (presumably uint8 pixel values in
        [0, 255] — confirm for other integer dtypes).
        """
        # [N, C, T, H, W] or [N, T, C, H, W].
        clip = data["samples"]["video"]
        if not isinstance(clip, torch.Tensor):
            clip = torch.from_numpy(clip)
        # Check the dtype rather than `isinstance(clip, torch.FloatTensor)`, which
        # only matches CPU tensors: an already-normalized float32 clip on another
        # device must not be divided by 255 a second time.
        if clip.dtype != torch.float32:
            # Convert to float, and normalize between 0 and 1.
            clip = clip / 255.0
        check_rgb_video_tensor(clip)
        data["samples"]["video"] = clip
        return data
@TRANSFORMATIONS_REGISTRY.register(name="to_array", type="video")
class ToPixelArray(BaseTransformation):
    """
    This method is an inverse of ToTensor, converting a float tensor in range [0,1] back
    to a numpy uint8 array in range [0,255].

    Tensor shape abbreviations:
        N: Number of clips.
        T: Temporal length.
        C: Number of color channels.
        H, W: Height, Width.
    """

    def __call__(self, data: Dict) -> Dict:
        # [N, C, T, H, W] or [N, T, C, H, W].
        video = data["samples"]["video"]
        # Interpolation (bicubic in particular) can push values slightly outside
        # [0, 1]; clamp so the uint8 cast saturates instead of producing
        # out-of-range (ill-defined) conversions.
        video = (video * 255.0).round().clamp(0, 255)
        # detach()/cpu() so tensors that require grad or live on another device
        # can still be converted to numpy.
        video = video.detach().cpu().numpy().astype(np.uint8)
        data["samples"]["video"] = video
        return data
@TRANSFORMATIONS_REGISTRY.register(name="save-inputs", type="video")
class SaveInputs(BaseTransformation):
    def __init__(
        self,
        opts: argparse.Namespace,
        get_frame_captions: Optional[Callable[[Dict], List[str]]] = None,
        *args,
        **kwargs,
    ) -> None:
        """Saves the clips that are returned by VideoDataset.__getitem__() to disk
        for debugging use cases. This transformation operates on multiple clips that
        are extracted out of a single raw video. The video and audio of the clips are
        concatenated and saved into 1 video file.

        1 raw input video ==> VideoDataset.__getitem__() ==>
        multiple clips in data["samples"]["video"] ==> SaveInputs() ==>
        1 output debugging video.

        This is useful for visualizing training and/or validation videos to make
        sure preprocessing logic is behaving as expected.

        Args:
            opts: Command line options.
            get_frame_captions: If provided, this function returns a list of strings
                (one string per video frame). The frame captions will be added to the
                video as subtitles.
        """
        super().__init__(opts=opts)
        self.get_frame_captions = get_frame_captions
        self.enable = getattr(opts, "video_augmentation.save_inputs.enable")
        # Only reported by __repr__ in this class.
        self.add_labels = getattr(
            opts, "video_augmentation.save_inputs.add_labels", False
        )
        self.symlink_to_original = getattr(
            opts, "video_augmentation.save_inputs.symlink_to_original"
        )
        save_dir = getattr(opts, "video_augmentation.save_inputs.save_dir")
        if self.enable and save_dir is None:
            logger.error(
                "Please provide value for --video-augmentation.save-inputs.save-dir"
            )
        if save_dir is None:
            # `Path(None, ...)` raises TypeError, which would otherwise break
            # construction when saving is disabled and no save dir is configured.
            self.save_dir = None
        else:
            # Outputs of one process are grouped under a directory named after the
            # process start time.
            process_start_time = datetime.datetime.fromtimestamp(
                psutil.Process(os.getpid()).create_time()
            ).strftime("%Y-%m-%d %H:%M")
            self.save_dir = Path(save_dir, process_start_time).expanduser()

    def __call__(self, data: Dict) -> Dict:
        if not self.enable:
            return data

        original_path = data["samples"]["metadata"]["filename"]
        original_basename = os.path.basename(original_path)
        original_path_hash = hashlib.md5(str(original_path).encode()).hexdigest()[:5]
        # Microsecond-resolution timestamp so that saving the same source video
        # again (e.g. in a later epoch) produces a new file rather than colliding.
        timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S%f")
        output_video_path = Path(
            self.save_dir,
            f"{timestamp}_{original_path_hash}_{original_basename}",
        )
        self.save_video_with_annotations(
            data=data,
            output_video_path=output_video_path,
        )
        if self.symlink_to_original:
            # Symlink targets are resolved relative to the link's directory, so use
            # an absolute target. `suffix` already starts with ".", e.g.
            # "x.mp4" -> "x.original.mp4".
            os.symlink(
                os.path.abspath(original_path),
                output_video_path.with_suffix(f".original{output_video_path.suffix}"),
            )
        return data

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--video-augmentation.save-inputs.save-dir",
            type=str,
            default=None,
            help=(
                "Path to the folder for saving output debugging videos. Defaults to"
                " None."
            ),
        )
        group.add_argument(
            "--video-augmentation.save-inputs.add-labels",
            action="store_true",
            default=False,
            help=(
                "If set, write the class label on each frame of the video. Defaults to"
                " False."
            ),
        )
        group.add_argument(
            "--video-augmentation.save-inputs.enable",
            action="store_true",
            default=False,
            help=(
                "Use {}. This flag is useful when you want to study the effect of"
                " different transforms. Defaults to False.".format(cls.__name__)
            ),
        )
        group.add_argument(
            "--video-augmentation.save-inputs.symlink-to-original",
            action="store_true",
            default=False,
            help=(
                "If True, a symlink to original video sample will be created beside"
                " the saved inputs for easier debugging. Defaults to False."
            ),
        )
        return parser

    @staticmethod
    def _srt_format_timestamp(t: float) -> str:
        """Format ``t`` seconds as an SRT timestamp ``HH:MM:SS,mmm``."""
        total_millis = int(t * 1000)
        total_seconds, millis = divmod(total_millis, 1000)
        total_minutes, ss = divmod(total_seconds, 60)
        hh, mm = divmod(total_minutes, 60)
        return f"{hh:02d}:{mm:02d}:{ss:02d},{millis:03d}"

    def save_video_with_annotations(
        self,
        data: Dict,
        output_video_path: Path,
    ) -> None:
        """Save a video with audio and captions.

        Args:
            data: Dataset output dict. Schema: {
                "samples": {
                    "video": Tensor[N x T X C x H x W],
                    "audio": Tensor[N x T_audio x C],  # Optional
                    "audio_raw": Tensor[N x T_audio x C],  # Optional - if provided,
                        # "audio" will be ignored.
                    "metadata": {
                        "video_fps": Union[float,int],
                        "audio_fps": Union[float,int],
                    }
                }
            }
            output_video_path: Path for saving the video. Its suffix must be one of
                ".mp4", ".mov" or ".mkv".

        If ``self.get_frame_captions`` is set, it is called with ``data`` and its
        captions (one per video frame) are rendered onto the frames with ffmpeg's
        ``subtitles`` video filter.
        """
        samples = data["samples"]
        video = samples["video"]  # N x T x C x H x W
        video = video.reshape(-1, *video.shape[2:])  # (N*T) x C x H x W
        video_fps = samples["metadata"]["video_fps"]

        # "audio_raw" takes precedence over "audio" when both are present.
        audio = samples.get("audio_raw")  # N x T_audio x C
        if audio is None:
            audio = samples.get("audio")  # N x T_audio x C
        if audio is not None:
            audio = audio.reshape(-1, *audio.shape[2:])  # (N*T_audio) x C
            audio_fps = int(round(samples["metadata"]["audio_fps"]))

        # Clamp before the uint8 cast so values slightly outside [0, 1] saturate.
        video = (video * 255).round().clamp(0, 255).to(dtype=torch.uint8).cpu()
        video = video.permute([0, 2, 3, 1])  # (N*T) x H x W x C

        suffix = output_video_path.suffix
        assert suffix in (
            ".mp4",
            ".mov",
            ".mkv",
        ), f"{suffix} format is not supported by SaveInputs yet."
        output_video_path.parent.mkdir(exist_ok=True, parents=True)

        if audio is not None or self.get_frame_captions is not None:
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Write the raw frames first, then let ffmpeg mux in audio and/or
                # burn in subtitles.
                tmp_video = Path(tmp_dir, "video" + suffix)
                write_video(str(tmp_video), video_array=video, fps=video_fps)
                command = ["ffmpeg", "-i", str(tmp_video)]
                if audio is not None:
                    tmp_audio = str(Path(tmp_dir, "audio.wav"))
                    # torchaudio.save expects channels-first audio.
                    torchaudio.save(tmp_audio, audio.transpose(0, 1), audio_fps)
                    command.extend(["-i", tmp_audio])
                command.extend(["-c:v", "libx264"])
                if audio is not None:
                    command.extend(["-c:a", "aac"])
                if self.get_frame_captions:
                    captions = self.get_frame_captions(data)
                    tmp_srt = str(Path(tmp_dir, "subtitle.srt"))
                    with open(tmp_srt, "wt") as srt:
                        # One subtitle entry per frame, spanning that frame's
                        # display interval.
                        for i, caption in enumerate(captions):
                            srt.write(
                                f"{i+1}\n"
                                f"{self._srt_format_timestamp(i / video_fps)} --> "
                                f"{self._srt_format_timestamp((i+1) / video_fps)}\n"
                                f"{caption}\n\n"
                            )
                    command.extend(
                        [
                            "-vf",
                            f"subtitles={tmp_srt}:force_style='Alignment=6,Fontsize=48,Outline=8'",
                        ]
                    )
                subprocess.check_output(
                    [*command, f"file:{output_video_path}"],
                    stderr=subprocess.PIPE,
                )
        else:
            write_video(str(output_video_path), video_array=video, fps=video_fps)

    def __repr__(self) -> str:
        return (
            "{}(save_dir={}, add_labels={}, symlink_to_original={}, enable={})".format(
                self.__class__.__name__,
                self.save_dir,
                self.add_labels,
                self.symlink_to_original,
                self.enable,
            )
        )
@TRANSFORMATIONS_REGISTRY.register(name="random_resized_crop", type="video")
class RandomResizedCrop(BaseTransformation):
    """
    This class crops a random portion of a video and resizes it to a given size.
    """

    def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None:
        interpolation = getattr(
            opts,
            "video_augmentation.random_resized_crop.interpolation",
        )
        scale = getattr(opts, "video_augmentation.random_resized_crop.scale")
        ratio = getattr(
            opts,
            "video_augmentation.random_resized_crop.aspect_ratio",
        )
        # Reject the value if ANY requirement fails (type, length, or range).
        if (
            not isinstance(scale, Sequence)
            or len(scale) != 2
            or not 0.0 <= scale[0] <= scale[1]
        ):
            logger.error(
                "--video-augmentation.random-resized-crop.scale should be a tuple of"
                f" length 2 such that 0.0 <= scale[0] <= scale[1]. Got: {scale}."
            )
        if (
            not isinstance(ratio, Sequence)
            or len(ratio) != 2
            or not 0.0 < ratio[0] <= ratio[1]
        ):
            logger.error(
                "--video-augmentation.random-resized-crop.aspect-ratio should be a"
                f" tuple of length 2 such that 0.0 < ratio[0] <= ratio[1]. Got: {ratio}."
            )
        ratio = (round(ratio[0], 3), round(ratio[1], 3))

        super().__init__(opts=opts)

        self.scale = scale
        self.size = setup_size(size=size)
        self.interpolation = _check_interpolation(interpolation)
        self.ratio = ratio
        self.enable = getattr(opts, "video_augmentation.random_resized_crop.enable")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--video-augmentation.random-resized-crop.enable",
            action="store_true",
            help=(
                "Use {}. This flag is useful when you want to study the effect of"
                " different transforms. Defaults to False.".format(cls.__name__)
            ),
        )
        group.add_argument(
            "--video-augmentation.random-resized-crop.interpolation",
            type=str,
            default="bilinear",
            choices=SUPPORTED_PYTORCH_INTERPOLATIONS,
            help="Desired interpolation method. Defaults to bilinear",
        )
        group.add_argument(
            "--video-augmentation.random-resized-crop.scale",
            type=JsonValidator(Tuple[float, float]),
            default=(0.08, 1.0),
            help=(
                "Specifies the lower and upper bounds for the random area of the crop,"
                " before resizing. The scale is defined with respect to the area of the"
                " original image. Defaults to (0.08, 1.0)."
            ),
        )
        group.add_argument(
            "--video-augmentation.random-resized-crop.aspect-ratio",
            type=JsonValidator(Union[float, tuple]),
            default=(3.0 / 4.0, 4.0 / 3.0),
            help=(
                "lower and upper bounds for the random aspect ratio of the crop,"
                " before resizing. Defaults to (3./4., 4./3.)."
            ),
        )
        return parser

    def get_params(self, height: int, width: int) -> Tuple[int, int, int, int]:
        """Sample a crop box ``(top, left, crop_height, crop_width)``.

        Tries up to 10 random (area, aspect ratio) samples; the aspect ratio is
        sampled log-uniformly in ``self.ratio``. If none fits, falls back to the
        largest central crop whose aspect ratio lies within ``self.ratio``.
        """
        area = height * width
        # Loop-invariant: the log-space bounds of the aspect ratio.
        log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
        for _ in range(10):
            target_area = random.uniform(*self.scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop.
        in_ratio = (1.0 * width) / height
        if in_ratio < min(self.ratio):
            w = width
            h = int(round(w / min(self.ratio)))
        elif in_ratio > max(self.ratio):
            h = height
            w = int(round(h * max(self.ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, data: Dict) -> Dict:
        clip = data["samples"]["video"]
        check_rgb_video_tensor(clip=clip)
        height, width = clip.shape[-2:]
        i, j, h, w = self.get_params(height=height, width=width)
        data = _crop_fn(data=data, i=i, j=j, h=h, w=w)
        return _resize_fn(data=data, size=self.size, interpolation=self.interpolation)

    def __repr__(self) -> str:
        return "{}(scale={}, ratio={}, interpolation={}, enable={})".format(
            self.__class__.__name__,
            self.scale,
            self.ratio,
            self.interpolation,
            self.enable,
        )
@TRANSFORMATIONS_REGISTRY.register(name="random_short_side_resize_crop", type="video")
class RandomShortSizeResizeCrop(BaseTransformation):
    """
    This class first randomly resizes the input video such that shortest side is between
    specified minimum and maximum values, and then crops a desired size video.

    .. note::
        This class assumes that the video size after resizing is greater than or equal
        to the desired size.
    """

    def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None:
        """
        Args:
            opts: Command line options.
            size: Crop size, an int or a (height, width) pair. It is normalized with
                ``setup_size`` (as in ``RandomCrop``) and unpacked as
                ``(height, width)`` in ``get_params``.
        """
        interpolation = getattr(
            opts,
            "video_augmentation.random_short_side_resize_crop.interpolation",
        )
        short_size_min = getattr(
            opts,
            "video_augmentation.random_short_side_resize_crop.short_side_min",
        )
        short_size_max = getattr(
            opts,
            "video_augmentation.random_short_side_resize_crop.short_side_max",
        )

        if short_size_min is None:
            logger.error(
                "Short side minimum value can't be None in {}".format(
                    self.__class__.__name__
                )
            )
        if short_size_max is None:
            logger.error(
                "Short side maximum value can't be None in {}".format(
                    self.__class__.__name__
                )
            )
        # Equal bounds are valid (the short side is then fixed), matching the
        # ">=" wording of the error message.
        if short_size_max < short_size_min:
            logger.error(
                "Short side maximum value should be >= short side minimum value in {}."
                " Got: {} and {}".format(
                    self.__class__.__name__, short_size_max, short_size_min
                )
            )

        super().__init__(opts=opts)
        self.short_side_min = short_size_min
        # Normalize so `get_params` can always unpack `(height, width)`.
        self.size = setup_size(size=size)
        self.short_side_max = short_size_max
        self.interpolation = _check_interpolation(interpolation)
        self.enable = getattr(
            opts, "video_augmentation.random_short_side_resize_crop.enable"
        )

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--video-augmentation.random-short-side-resize-crop.enable",
            action="store_true",
            help=(
                "Use {}. This flag is useful when you want to study the effect of"
                " different transforms. Defaults to False.".format(cls.__name__)
            ),
        )
        group.add_argument(
            "--video-augmentation.random-short-side-resize-crop.interpolation",
            type=str,
            default="bilinear",
            choices=SUPPORTED_PYTORCH_INTERPOLATIONS,
            help="Desired interpolation method. Defaults to bilinear",
        )
        group.add_argument(
            "--video-augmentation.random-short-side-resize-crop.short-side-min",
            type=int,
            default=None,
            help="Minimum value for video's shortest side. Defaults to None.",
        )
        group.add_argument(
            "--video-augmentation.random-short-side-resize-crop.short-side-max",
            type=int,
            default=None,
            help="Maximum value for video's shortest side. Defaults to None.",
        )
        return parser

    def get_params(self, height, width) -> Tuple[int, int, int, int]:
        th, tw = self.size
        if width == tw and height == th:
            return 0, 0, height, width

        i = random.randint(0, height - th)
        j = random.randint(0, width - tw)
        return i, j, th, tw

    def __call__(self, data: Dict) -> Dict:
        # Sample the target short side uniformly from [min, max] (both inclusive).
        short_dim = random.randint(self.short_side_min, self.short_side_max)
        # resize the video so that shorter side is short_dim
        data = _resize_fn(data, size=short_dim, interpolation=self.interpolation)

        clip = data["samples"]["video"]
        check_rgb_video_tensor(clip=clip)
        height, width = clip.shape[-2:]
        i, j, h, w = self.get_params(height=height, width=width)
        # Crop the video.
        return _crop_fn(data=data, i=i, j=j, h=h, w=w)

    def __repr__(self) -> str:
        return "{}(size={}, short_size_range=({}, {}), interpolation={}, enable={})".format(
            self.__class__.__name__,
            self.size,
            self.short_side_min,
            self.short_side_max,
            self.interpolation,
            self.enable,
        )
@TRANSFORMATIONS_REGISTRY.register(name="random_crop", type="video")
class RandomCrop(BaseTransformation):
    """
    Randomly crop a fixed-size spatial window out of a video clip.

    .. note::
        The input video is expected to be at least as large as the requested crop
        in both height and width.
    """

    def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None:
        crop_size = setup_size(size=size)
        super().__init__(opts=opts)
        self.size = crop_size
        self.enable = getattr(opts, "video_augmentation.random_crop.enable")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        enable_help = (
            f"Use {cls.__name__}. This flag is useful when you want to study the effect"
            " of different transforms. Defaults to False."
        )
        group.add_argument(
            "--video-augmentation.random-crop.enable",
            action="store_true",
            help=enable_help,
        )
        return parser

    def get_params(self, height: int, width: int) -> Tuple[int, int, int, int]:
        """Return ``(top, left, crop_height, crop_width)`` for a random crop."""
        crop_h, crop_w = self.size
        # The input already has the target size: the only valid crop is the
        # whole frame.
        if (height, width) == (crop_h, crop_w):
            return 0, 0, height, width
        top = random.randint(0, height - crop_h)
        left = random.randint(0, width - crop_w)
        return top, left, crop_h, crop_w

    def __call__(self, data: Dict) -> Dict:
        video = data["samples"]["video"]
        check_rgb_video_tensor(clip=video)
        frame_h, frame_w = video.shape[-2:]
        top, left, crop_h, crop_w = self.get_params(height=frame_h, width=frame_w)
        return _crop_fn(data=data, i=top, j=left, h=crop_h, w=crop_w)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(crop_size={self.size}, enable={self.enable})"
@TRANSFORMATIONS_REGISTRY.register(name="random_horizontal_flip", type="video")
class RandomHorizontalFlip(BaseTransformation):
    """
    This class implements random horizontal flipping method
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        p = getattr(opts, "video_augmentation.random_horizontal_flip.p", 0.5)
        super().__init__(opts=opts)
        self.p = p
        self.enable = getattr(opts, "video_augmentation.random_horizontal_flip.enable")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--video-augmentation.random-horizontal-flip.enable",
            action="store_true",
            help=(
                "Use {}. This flag is useful when you want to study the effect of"
                " different transforms. Defaults to False.".format(cls.__name__)
            ),
        )
        group.add_argument(
            "--video-augmentation.random-horizontal-flip.p",
            type=float,
            default=0.5,
            help="Probability for random horizontal flip. Defaults to 0.5.",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        """With probability ``self.p``, flip the video and any mask along width."""
        # random.random() is in [0, 1): with `<`, p=0 never flips and p=1 always
        # flips.
        if random.random() < self.p:
            clip = data["samples"]["video"]
            check_rgb_video_tensor(clip=clip)
            data["samples"]["video"] = torch.flip(clip, dims=[-1])

            # Masks may be stored in `data["samples"]["mask"]` (where `_resize_fn`
            # keeps them) or at the top level `data["mask"]`; flip each one in
            # place so it stays aligned with the flipped video.
            for container in (data["samples"], data):
                mask = container.get("mask", None)
                if mask is not None:
                    container["mask"] = torch.flip(mask, dims=[-1])
        return data

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(flip probability={self.p},"
            f" enable={self.enable})"
        )
@TRANSFORMATIONS_REGISTRY.register(name="center_crop", type="video")
class CenterCrop(BaseTransformation):
    """
    This class implements center cropping method.

    .. note::
        This class assumes that the input size is greater than or equal to the desired
        size.
    """

    def __init__(self, opts, size: Union[Sequence, int], *args, **kwargs) -> None:
        """
        Args:
            opts: Command line options.
            size: Crop size: an int (used for both height and width), a one-element
                sequence (same), or a ``(height, width)`` pair.
        """
        super().__init__(opts=opts)
        if isinstance(size, Sequence) and len(size) == 2:
            self.height, self.width = size[0], size[1]
        elif isinstance(size, Sequence) and len(size) == 1:
            self.height = self.width = size[0]
        elif isinstance(size, int):
            self.height = self.width = size
        else:
            logger.error("Size should be either an int or tuple of ints.")
        self.enable = getattr(opts, "video_augmentation.center_crop.enable")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--video-augmentation.center-crop.enable",
            action="store_true",
            help=(
                "Use {}. This flag is useful when you want to study the effect of"
                " different transforms. Defaults to False.".format(cls.__name__)
            ),
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        height, width = data["samples"]["video"].shape[-2:]
        # Top-left corner of a window of the requested size centered in the frame.
        i = (height - self.height) // 2
        j = (width - self.width) // 2
        return _crop_fn(data=data, i=i, j=j, h=self.height, w=self.width)

    def __repr__(self) -> str:
        return "{}(size=(h={}, w={}), enable={})".format(
            self.__class__.__name__, self.height, self.width, self.enable
        )
@TRANSFORMATIONS_REGISTRY.register(name="resize", type="video")
class Resize(BaseTransformation):
    """
    This class implements resizing operation.

    .. note::
    Two possible modes for resizing.
    1. Resize while maintaining aspect ratio. To enable this option, pass int as a size.
    2. Resize to a fixed size. To enable this option, pass a tuple of height and width
        as a size.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        size = getattr(opts, "video_augmentation.resize.size", None)
        if size is None:
            logger.error("Size can not be None in {}".format(self.__class__.__name__))

        # Possible modes.
        # 1. Resize while maintaining aspect ratio. To enable this option, pass int as a
        # size.
        # 2. Resize to a fixed size. To enable this option, pass a tuple of height and
        # width as a size.
        if isinstance(size, Sequence):
            if not 1 <= len(size) <= 2:
                logger.error(
                    "The length of size should be either 1 or 2 in {}".format(
                        self.__class__.__name__
                    )
                )
            elif len(size) == 1:
                # `--video-augmentation.resize.size` uses nargs="+", so a single
                # value arrives as a one-element list. Unwrap it so `_resize_fn`
                # treats it as mode 1 (resize the shorter side).
                size = size[0]

        interpolation = getattr(
            opts, "video_augmentation.resize.interpolation", "bilinear"
        )
        super().__init__(opts=opts)
        self.size = size
        self.interpolation = _check_interpolation(interpolation)
        self.enable = getattr(opts, "video_augmentation.resize.enable")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--video-augmentation.resize.enable",
            action="store_true",
            help=(
                "Use {}. This flag is useful when you want to study the effect of"
                " different transforms. Defaults to False.".format(cls.__name__)
            ),
        )
        group.add_argument(
            "--video-augmentation.resize.interpolation",
            type=str,
            default="bilinear",
            choices=SUPPORTED_PYTORCH_INTERPOLATIONS,
            help="Interpolation for resizing. Defaults to bilinear",
        )
        group.add_argument(
            "--video-augmentation.resize.size",
            type=int,
            nargs="+",
            default=None,
            help=(
                "Resize video to the specified size. If int is passed, then shorter"
                " side is resized to the specified size and longest side is resized"
                " while maintaining aspect ratio. Defaults to None."
            ),
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        return _resize_fn(data=data, size=self.size, interpolation=self.interpolation)

    def __repr__(self) -> str:
        return "{}(size={}, interpolation={}, enable={})".format(
            self.__class__.__name__, self.size, self.interpolation, self.enable
        )
@TRANSFORMATIONS_REGISTRY.register(name="crop_by_bounding_box", type="video")
class CropByBoundingBox(BaseTransformation):
"""Crops video frames based on bounding boxes and adjusts the @targets
"box_coordinates" annotations.
Before cropping, the bounding boxes are expanded with @multiplier, while the
"box_coordinates" cover the original areas of the image.
Note that the cropped images may be padded with 0 values in the boundaries of the
cropped image when the bounding boxes are near the edges.
Frames with invalid bounding boxes (with x0=y0=x1=y1=-1, or with area <5) will be
blacked out in the output. Alternatively, we could have dropped them, which is not
implemented yet.
"""
BBOX_MIN_AREA = 5 # Minimum valid bounding box area (in pixels).
def __init__(
self,
opts: argparse.Namespace,
image_size: Optional[Tuple[int, int]] = None,
is_training: bool = False,
*args,
**kwargs,
) -> None:
super().__init__(opts=opts, *args, **kwargs)
self.is_training = is_training
self.multiplier = getattr(
opts, "video_augmentation.crop_by_bounding_box.multiplier"
)
self.multiplier_range = getattr(
opts, "video_augmentation.crop_by_bounding_box.multiplier_range"
)
if image_size is None:
self.image_size = getattr(
opts, "video_augmentation.crop_by_bounding_box.image_size"
)
else:
self.image_size = image_size
assert image_size is not None, (
"Please provide --video-augmentation.crop-by-bounding-box.image_size"
" argument."
)
self.channel_first = getattr(
opts, "video_augmentation.crop_by_bounding_box.channel_first"
)
self.interpolation = getattr(
opts, "video_augmentation.crop_by_bounding_box.interpolation"
)
def __call__(self, data: Dict, *args, **kwargs) -> Dict:
"""
Tensor shape abbreviations:
N: Number of clips.
T, T_audio, T_video: Temporal lengths.
C: Number of color channels.
H, W: Height, Width.
Args:
data: mapping of: {
"samples": {
"video": Tensor of shape: [N, C, T, H, W] if self.channel_first else [N, T, C, H, W]
},
"targets": {
"traces": {
"<object_trace_uuid>": {
"box_coordinates": FloatTensor[N, T, 4], # x0, y0, x1, y1
}
},
"labels": IntTensor[N, T],
}
}
Note:
This transformation does not modify the "labels". If frames that are
blacked out due to having invalid bounding boxes need a different label,
datasets should alter the labels according to the following logic:
```
data = CropByBoundingBox(opts)(data)