]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/io/video.py
renamed pytorch_jacinto_ai.vision to pytorch_jacinto_ai.xvision
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / io / video.py
1 import re
2 import gc
3 import torch
4 import numpy as np
6 try:
7     import av
8     av.logging.set_level(av.logging.ERROR)
9     if not hasattr(av.video.frame.VideoFrame, 'pict_type'):
10         av = ImportError("""\
11 Your version of PyAV is too old for the necessary video operations in torchvision.
12 If you are on Python 3.5, you will have to build from source (the conda-forge
13 packages are not up-to-date).  See
14 https://github.com/mikeboers/PyAV#installation for instructions on how to
15 install PyAV on your system.
16 """)
17 except ImportError:
18     av = ImportError("""\
19 PyAV is not installed, and is necessary for the video operations in torchvision.
20 See https://github.com/mikeboers/PyAV#installation for instructions on how to
21 install PyAV on your system.
22 """)
25 def _check_av_available():
26     if isinstance(av, Exception):
27         raise av
30 def _av_available():
31     return not isinstance(av, Exception)
34 # PyAV has some reference cycles
35 _CALLED_TIMES = 0
36 _GC_COLLECTION_INTERVAL = 10
39 def write_video(filename, video_array, fps, video_codec='libx264', options=None):
40     """
41     Writes a 4d tensor in [T, H, W, C] format in a video file
43     Parameters
44     ----------
45     filename : str
46         path where the video will be saved
47     video_array : Tensor[T, H, W, C]
48         tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format
49     fps : Number
50         frames per second
51     """
52     _check_av_available()
53     video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
55     container = av.open(filename, mode='w')
57     stream = container.add_stream(video_codec, rate=fps)
58     stream.width = video_array.shape[2]
59     stream.height = video_array.shape[1]
60     stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24'
61     stream.options = options or {}
63     for img in video_array:
64         frame = av.VideoFrame.from_ndarray(img, format='rgb24')
65         frame.pict_type = 'NONE'
66         for packet in stream.encode(frame):
67             container.mux(packet)
69     # Flush stream
70     for packet in stream.encode():
71         container.mux(packet)
73     # Close the file
74     container.close()
77 def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
78     global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
79     _CALLED_TIMES += 1
80     if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
81         gc.collect()
83     frames = {}
84     should_buffer = False
85     max_buffer_size = 5
86     if stream.type == "video":
87         # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
88         # so need to buffer some extra frames to sort everything
89         # properly
90         extradata = stream.codec_context.extradata
91         # overly complicated way of finding if `divx_packed` is set, following
92         # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
93         if extradata and b"DivX" in extradata:
94             # can't use regex directly because of some weird characters sometimes...
95             pos = extradata.find(b"DivX")
96             d = extradata[pos:]
97             o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
98             if o is None:
99                 o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
100             if o is not None:
101                 should_buffer = o.group(3) == b"p"
102     seek_offset = start_offset
103     # some files don't seek to the right location, so better be safe here
104     seek_offset = max(seek_offset - 1, 0)
105     if should_buffer:
106         # FIXME this is kind of a hack, but we will jump to the previous keyframe
107         # so this will be safe
108         seek_offset = max(seek_offset - max_buffer_size, 0)
109     try:
110         # TODO check if stream needs to always be the video stream here or not
111         container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
112     except av.AVError:
113         # TODO add some warnings in this case
114         # print("Corrupted file?", container.name)
115         return []
116     buffer_count = 0
117     for idx, frame in enumerate(container.decode(**stream_name)):
118         frames[frame.pts] = frame
119         if frame.pts >= end_offset:
120             if should_buffer and buffer_count < max_buffer_size:
121                 buffer_count += 1
122                 continue
123             break
124     # ensure that the results are sorted wrt the pts
125     result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
126     if start_offset > 0 and start_offset not in frames:
127         # if there is no frame that exactly matches the pts of start_offset
128         # add the last frame smaller than start_offset, to guarantee that
129         # we will have all the necessary data. This is most useful for audio
130         first_frame_pts = max(i for i in frames if i < start_offset)
131         result.insert(0, frames[first_frame_pts])
132     return result
135 def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
136     start, end = audio_frames[0].pts, audio_frames[-1].pts
137     total_aframes = aframes.shape[1]
138     step_per_aframe = (end - start + 1) / total_aframes
139     s_idx = 0
140     e_idx = total_aframes
141     if start < ref_start:
142         s_idx = int((ref_start - start) / step_per_aframe)
143     if end > ref_end:
144         e_idx = int((ref_end - end) / step_per_aframe)
145     return aframes[:, s_idx:e_idx]
148 def read_video(filename, start_pts=0, end_pts=None):
149     """
150     Reads a video from a file, returning both the video frames as well as
151     the audio frames
153     Parameters
154     ----------
155     filename : str
156         path to the video file
157     start_pts : int, optional
158         the start presentation time of the video
159     end_pts : int, optional
160         the end presentation time
162     Returns
163     -------
164     vframes : Tensor[T, H, W, C]
165         the `T` video frames
166     aframes : Tensor[K, L]
167         the audio frames, where `K` is the number of channels and `L` is the
168         number of points
169     info : Dict
170         metadata for the video and audio. Can contain the fields video_fps (float)
171         and audio_fps (int)
172     """
173     _check_av_available()
175     if end_pts is None:
176         end_pts = float("inf")
178     if end_pts < start_pts:
179         raise ValueError("end_pts should be larger than start_pts, got "
180                          "start_pts={} and end_pts={}".format(start_pts, end_pts))
182     container = av.open(filename, metadata_errors='ignore')
183     info = {}
185     video_frames = []
186     if container.streams.video:
187         video_frames = _read_from_stream(container, start_pts, end_pts,
188                                          container.streams.video[0], {'video': 0})
189         info["video_fps"] = float(container.streams.video[0].average_rate)
190     audio_frames = []
191     if container.streams.audio:
192         audio_frames = _read_from_stream(container, start_pts, end_pts,
193                                          container.streams.audio[0], {'audio': 0})
194         info["audio_fps"] = container.streams.audio[0].rate
196     container.close()
198     vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
199     aframes = [frame.to_ndarray() for frame in audio_frames]
200     vframes = torch.as_tensor(np.stack(vframes))
201     if aframes:
202         aframes = np.concatenate(aframes, 1)
203         aframes = torch.as_tensor(aframes)
204         aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
205     else:
206         aframes = torch.empty((1, 0), dtype=torch.float32)
208     return vframes, aframes, info
211 def _can_read_timestamps_from_packets(container):
212     extradata = container.streams[0].codec_context.extradata
213     if extradata is None:
214         return False
215     if b"Lavc" in extradata:
216         return True
217     return False
220 def read_video_timestamps(filename):
221     """
222     List the video frames timestamps.
224     Note that the function decodes the whole video frame-by-frame.
226     Parameters
227     ----------
228     filename : str
229         path to the video file
231     Returns
232     -------
233     pts : List[int]
234         presentation timestamps for each one of the frames in the video.
235     video_fps : int
236         the frame rate for the video
238     """
239     _check_av_available()
240     container = av.open(filename, metadata_errors='ignore')
242     video_frames = []
243     video_fps = None
244     if container.streams.video:
245         if _can_read_timestamps_from_packets(container):
246             # fast path
247             video_frames = [x for x in container.demux(video=0) if x.pts is not None]
248         else:
249             video_frames = _read_from_stream(container, 0, float("inf"),
250                                              container.streams.video[0], {'video': 0})
251         video_fps = float(container.streams.video[0].average_rate)
252     container.close()
253     return [x.pts for x in video_frames], video_fps