1 import re
2 import gc
3 import torch
4 import numpy as np
# Import PyAV with a lazy failure mode: when PyAV is missing (or too old to
# expose `pict_type`, which write_video needs), the name `av` is rebound to an
# ImportError *instance*. _check_av_available() raises that instance on first
# use, so merely importing this module never fails without PyAV.
try:
    import av
    # Silence PyAV's internal logging below ERROR level.
    av.logging.set_level(av.logging.ERROR)
    # Old PyAV builds lack VideoFrame.pict_type, which write_video() sets.
    if not hasattr(av.video.frame.VideoFrame, 'pict_type'):
        av = ImportError("""\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")
except ImportError:
    av = ImportError("""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")
def _check_av_available():
    """Raise the stored ImportError if PyAV could not be imported."""
    if not isinstance(av, Exception):
        return
    raise av
def _av_available():
    """Return True when PyAV imported cleanly, False when `av` holds an error."""
    pyav_broken = isinstance(av, Exception)
    return not pyav_broken
# PyAV has some reference cycles, so we periodically force a garbage
# collection from _read_from_stream to keep memory bounded.
_CALLED_TIMES = 0  # number of _read_from_stream invocations so far
_GC_COLLECTION_INTERVAL = 10  # run gc.collect() once every this many calls
def write_video(filename, video_array, fps, video_codec='libx264', options=None):
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    Parameters
    ----------
    filename : str
        path where the video will be saved
    video_array : Tensor[T, H, W, C]
        tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format
    fps : Number
        frames per second
    video_codec : str, optional
        name of the ffmpeg encoder to use (default: 'libx264')
    options : Dict[str, str], optional
        codec-specific options forwarded to the output stream
    """
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    container = av.open(filename, mode='w')
    try:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        # libx264rgb encodes RGB directly; every other codec expects yuv420p
        stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24'
        stream.options = options or {}

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format='rgb24')
            frame.pict_type = 'NONE'
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush any frames still buffered inside the encoder
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Close the file even when encoding raises, so the handle is not leaked
        container.close()
def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
    """Decode frames from `stream` whose pts fall within [start_offset, end_offset].

    Parameters
    ----------
    container : av container
        an open PyAV container to seek and decode from
    start_offset : int
        start presentation timestamp, in the stream's time base
    end_offset : int or float
        end presentation timestamp (may be float("inf") to read to the end)
    stream : av stream
        the stream object used for seeking
    stream_name : Dict[str, int]
        keyword mapping forwarded to container.decode, e.g. {'video': 0}

    Returns
    -------
    List of decoded frames sorted by pts; may be prefixed with one frame just
    before start_offset (see note at the bottom), or empty if seeking failed.
    """
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    # PyAV has reference cycles; collect periodically to keep memory bounded.
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()
    # map pts -> frame, so duplicates collapse and we can sort by pts at the end
    frames = {}
    should_buffer = False
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                # a trailing 'p' marks the packed-B-frames variant
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    buffer_count = 0
    for idx, frame in enumerate(container.decode(**stream_name)):
        frames[frame.pts] = frame
        if frame.pts >= end_offset:
            # keep decoding a few extra frames for packed-B-frame files so
            # out-of-order pts still land inside the collected window
            if should_buffer and buffer_count < max_buffer_size:
                buffer_count += 1
                continue
            break
    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        first_frame_pts = max(i for i in frames if i < start_offset)
        result.insert(0, frames[first_frame_pts])
    return result
135 def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
136 start, end = audio_frames[0].pts, audio_frames[-1].pts
137 total_aframes = aframes.shape[1]
138 step_per_aframe = (end - start + 1) / total_aframes
139 s_idx = 0
140 e_idx = total_aframes
141 if start < ref_start:
142 s_idx = int((ref_start - start) / step_per_aframe)
143 if end > ref_end:
144 e_idx = int((ref_end - end) / step_per_aframe)
145 return aframes[:, s_idx:e_idx]
def read_video(filename, start_pts=0, end_pts=None):
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Parameters
    ----------
    filename : str
        path to the video file
    start_pts : int, optional
        the start presentation time of the video
    end_pts : int, optional
        the end presentation time

    Returns
    -------
    vframes : Tensor[T, H, W, C]
        the `T` video frames
    aframes : Tensor[K, L]
        the audio frames, where `K` is the number of channels and `L` is the
        number of points
    info : Dict
        metadata for the video and audio. Can contain the fields video_fps (float)
        and audio_fps (int)
    """
    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError("end_pts should be larger than start_pts, got "
                         "start_pts={} and end_pts={}".format(start_pts, end_pts))

    info = {}
    video_frames = []
    audio_frames = []

    container = av.open(filename, metadata_errors='ignore')
    try:
        if container.streams.video:
            video_frames = _read_from_stream(container, start_pts, end_pts,
                                             container.streams.video[0], {'video': 0})
            info["video_fps"] = float(container.streams.video[0].average_rate)
        if container.streams.audio:
            audio_frames = _read_from_stream(container, start_pts, end_pts,
                                             container.streams.audio[0], {'audio': 0})
            info["audio_fps"] = container.streams.audio[0].rate
    finally:
        # release the demuxer even if decoding raised, so the file isn't leaked
        container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]

    if vframes:
        vframes = torch.as_tensor(np.stack(vframes))
    else:
        # np.stack raises ValueError on an empty list (e.g. no frames in the
        # requested pts range); return an empty uint8 video tensor instead
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info
211 def _can_read_timestamps_from_packets(container):
212 extradata = container.streams[0].codec_context.extradata
213 if extradata is None:
214 return False
215 if b"Lavc" in extradata:
216 return True
217 return False
def read_video_timestamps(filename):
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Parameters
    ----------
    filename : str
        path to the video file

    Returns
    -------
    pts : List[int]
        presentation timestamps for each one of the frames in the video.
    video_fps : int
        the frame rate for the video
    """
    _check_av_available()
    container = av.open(filename, metadata_errors='ignore')

    video_frames = []
    video_fps = None
    try:
        if container.streams.video:
            if _can_read_timestamps_from_packets(container):
                # fast path: read pts straight off the demuxed packets,
                # skipping the (slow) frame-by-frame decode entirely
                video_frames = [x for x in container.demux(video=0) if x.pts is not None]
            else:
                video_frames = _read_from_stream(container, 0, float("inf"),
                                                 container.streams.video[0], {'video': 0})
            video_fps = float(container.streams.video[0].average_rate)
    finally:
        # close the file even when demux/decode raises, so the handle isn't leaked
        container.close()
    return [x.pts for x in video_frames], video_fps