Using CoreML in ARKit for Object Segmentation and Occlusion

In this post we’ll look at how to detect, classify, segment and occlude objects in ARKit using CoreML and the Vision framework.

We’ll use two machine learning models that are available from the Apple Developer website:

  • YOLOv3 to locate and classify an object
  • DeeplabV3 to segment the detected object’s pixels

This example runs on devices that don’t have a LiDAR sensor, so we’ll look at a way to ‘fake’ depth in a Metal fragment shader.

Take a look at this video to see what we’ll achieve:

Let’s get started.

Preparing The DeeplabV3 Model

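Out of the box, the DeepLabV3 model outputs its predictions as a 513x513 multi-array of classification indices. We’ll change that output to a grayscale image so that Vision hands us the result as a CVPixelBuffer (wrapped in a VNPixelBufferObservation), which we can later turn into a Metal texture. To do this you’ll need Python installed, as well as Apple’s coremltools. Installation instructions can be found on the coremltools GitHub page.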

Here’s how to use Python to modify the model (this page provided the inspiration):

import coremltools
import coremltools.proto.FeatureTypes_pb2 as ft
# Load the spec from the machine learning model
spec = coremltools.utils.load_spec("DeepLabV3Int8LUT.mlmodel")
# The output we'll have to modify (print it to inspect its current type)
output = spec.description.output[0]
# The model produces a 513x513 grid, so the output image has to be 513x513 too
output.type.imageType.height = 513
output.type.imageType.width = 513
# The model doesn't produce RGB values but single integers per pixel
# that represent the object classification
output.type.imageType.colorSpace = ft.ImageFeatureType.GRAYSCALE
# Save the modified model
coremltools.utils.save_spec(spec, "DeepLabV3Int8Image.mlmodel")

Now we’re ready to start coding.

Using Vision Framework for Detection, Classification and Segmentation

We’ll create two VNCoreMLRequest instances:

  • One for object detection. We use the YOLOv3 model to detect cars. This will give us a classification string (“car”) and a rectangle that represents the bounds of the detected object.
  • One for object segmentation. The rectangle that was produced by object detection is used as the region of interest for this request.

We’ll perform these requests inside SCNSceneRendererDelegate’s renderer(_:willRenderScene:atTime:) method:

Note that we’re passing both requests at once. The object detection request is processed first, which gives us the rectangle containing the detected object; that rectangle is then used by the next request.

Here we go through the detected objects and keep only the ones classified as “car”. If an actual car was detected, we enlarge its bounding box for better segmentation results and pass it on to the segmentation request as its region of interest.
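The filter-and-enlarge step boils down to rectangle math. Here is a minimal Python sketch of it, assuming each detection has already been reduced to its top label as a (label, confidence, box) tuple, with boxes in Vision's normalized coordinates (x, y, width, height in 0...1). Picking the most confident car and growing the box by 20% are assumptions for illustration; the helper names are hypothetical and not taken from the example project.

```python
# Hypothetical helpers sketching the filter-and-enlarge step; not the example project's code.
from typing import List, Optional, Tuple

# A normalized rectangle: (x, y, width, height), all values in 0...1,
# like the boundingBox of a Vision observation.
Rect = Tuple[float, float, float, float]


def enlarge(rect: Rect, factor: float) -> Rect:
    """Grow a normalized rect around its center and clamp it to the image."""
    x, y, w, h = rect
    cx, cy = x + w / 2, y + h / 2
    new_w, new_h = w * factor, h * factor
    # Clamp every edge so the region of interest stays inside 0...1
    min_x = max(0.0, cx - new_w / 2)
    min_y = max(0.0, cy - new_h / 2)
    max_x = min(1.0, cx + new_w / 2)
    max_y = min(1.0, cy + new_h / 2)
    return (min_x, min_y, max_x - min_x, max_y - min_y)


def car_region_of_interest(
    detections: List[Tuple[str, float, Rect]], factor: float = 1.2
) -> Optional[Rect]:
    """Keep only 'car' detections, pick the most confident one, enlarge its box."""
    cars = [d for d in detections if d[0] == "car"]
    if not cars:
        return None  # nothing to segment this frame
    best = max(cars, key=lambda d: d[1])
    return enlarge(best[2], factor)


detections = [
    ("person", 0.9, (0.1, 0.1, 0.2, 0.5)),
    ("car", 0.8, (0.4, 0.3, 0.4, 0.3)),
]
print(car_region_of_interest(detections))
```

Clamping keeps the enlarged box inside the normalized image bounds, since the region of interest is specified relative to the processed image.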

Before we process the segmentation results, we need to understand what they represent. DeepLab returns a 513x513 grid that assigns a classification to every pixel. Each classification is an index into a fixed array of labels defined by the model:

"background", "aeroplane", "bicycle", "bird", "board", "bottle", "bus", "car", "cat", "chair", "cow", "diningTable", "dog", "horse", "motorbike", "person", "pottedPlant", "sheep", "sofa", "train", "tvOrMonitor"

So a background pixel will be represented by 0, a bird by 3 and a car by 7. We’ll check for these values in the fragment shader to determine the visibility of the output pixel and depth in SceneKit.
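To make the index lookup concrete, here is a small Python check built from the label list above; the position of a label in the list is the value DeepLab writes for its pixels. The constant and function names are hypothetical.

```python
# Hypothetical names; the list is copied from the DeepLab labels above, in order.
DEEPLAB_LABELS = [
    "background", "aeroplane", "bicycle", "bird", "board", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningTable", "dog", "horse", "motorbike",
    "person", "pottedPlant", "sheep", "sofa", "train", "tvOrMonitor",
]


def label_index(label: str) -> int:
    """Return the per-pixel value DeepLab writes for a given label."""
    return DEEPLAB_LABELS.index(label)


for name in ("background", "bird", "car"):
    print(name, label_index(name))  # prints 0, 3 and 7 respectively
```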

However, before we can do that we’ll need to convert the CVPixelBuffer to an MTLTexture so we can use it in a Metal fragment shader.

A couple of things to note here. The CVPixelBuffer provided by VNPixelBufferObservation is not Metal compatible, which means that converting it to an MTLTexture will fail. Hence it is necessary to copy the buffer and set kCVPixelBufferMetalCompatibilityKey to true. A very useful method from CoreMLHelpers is used to accomplish this.

Also note the MTLTexture pixel format, r8Uint. As mentioned before, DeepLab doesn’t return RGB values but single integers representing classification indices, and with only 21 labels each index fits in one unsigned byte.
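Since every pixel is a single unsigned byte, reading a classification back is plain indexing, which is essentially what a read from the r8Uint texture does. This hedged Python sketch assumes rows may be padded (CVPixelBufferGetBytesPerRow can report a stride larger than the width) and uses a nearest-pixel lookup from normalized coordinates with a top-left origin, like Metal textures. The function name is hypothetical.

```python
# Hypothetical helper mirroring a nearest-pixel read from an r8Uint class-index buffer.


def class_index_at(buffer: bytes, width: int, height: int,
                   bytes_per_row: int, u: float, v: float) -> int:
    """Nearest-pixel lookup of a class index at normalized (u, v).

    (0, 0) is the top-left pixel, matching Metal's texture coordinate origin.
    """
    # Map 0...1 to a pixel column/row and clamp to the last pixel
    x = min(int(u * width), width - 1)
    y = min(int(v * height), height - 1)
    # Each row starts at a multiple of bytes_per_row, not of width
    return buffer[y * bytes_per_row + x]


# A tiny 3x2 "segmentation" with 1 byte of row padding (bytes_per_row = 4):
# row 0: background, car, car    row 1: person, car, background
pixels = bytes([0, 7, 7, 0,
                15, 7, 0, 0])
print(class_index_at(pixels, 3, 2, 4, 0.5, 0.25))  # prints 7 (center of top row)
```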

Let’s move on to visualizing these results using SceneKit and Metal.

SceneKit and Metal Shader Tricks

There is a nice way to accomplish this using some of the lesser-documented parts of SceneKit and the Metal Shading Language. Here’s the breakdown:

  • Create a node that will render as a full screen quad. As the name suggests, a full screen quad always covers the entire screen. We’ll use it to mask out the detected object.
  • Use an SCNNode so we can add it to the scene graph. We want to hook into SceneKit’s render cycle, but we don’t want SceneKit to actually render our node; we’ll do our own custom rendering with a Metal shader. This means we can’t attach any SCNGeometry. Our geometry isn’t complex and doesn’t need any SceneKit-specific information (like the model-view-projection matrix), so we can generate it in the vertex shader.
  • Use SCNNodeRendererDelegate’s renderNode(_:renderer:arguments:) to hook into SceneKit’s render cycle and grab the current MTLRenderCommandEncoder, so we can render our full screen quad and write into the depth buffer.
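A common way to generate a full screen quad in the vertex shader is to derive each corner from the vertex index, so no vertex buffer is needed. The sketch below reproduces that arithmetic in Python for a four-vertex triangle strip; the example project may use a different scheme, and the function name is hypothetical. Positions are in Metal's normalized device coordinates (-1...1, y up) and texture coordinates use Metal's top-left origin.

```python
# Hypothetical helper mirroring vertex-id based full screen quad generation.
from typing import Tuple


def quad_vertex(vertex_id: int) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """Return (clip-space xy, texture uv) for corner 0...3 of a triangle strip."""
    # Bit 0 selects left/right, bit 1 selects bottom/top
    x = -1.0 + 2.0 * (vertex_id & 1)
    y = -1.0 + 2.0 * ((vertex_id >> 1) & 1)
    # Texture v runs downwards, so flip y when mapping to 0...1
    u = (x + 1.0) / 2.0
    v = 1.0 - (y + 1.0) / 2.0
    return (x, y), (u, v)


for vid in range(4):
    print(vid, quad_vertex(vid))
```

The strip's two triangles (corners 0, 1, 2 and 1, 2, 3) cover the whole screen. On the Swift side this pairs with a call like drawPrimitives(type: .triangleStrip, vertexStart: 0, vertexCount: 4) on the render command encoder.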

In the example, this class is called SegmentationMaskNode. Here’s a summary of the most important bits:

  • We’ll need to create an MTLRenderPipelineDescriptor that has a vertex and a fragment shader attached.
  • In order to write into the depth buffer we’ll also need to create an MTLDepthStencilDescriptor and set isDepthWriteEnabled to true.
  • In the renderNode delegate method we get the current MTLRenderCommandEncoder.
  • Here we also create the shader uniforms that pass these values on to the shader: the region of interest, the value to write into the depth buffer, an aspect ratio correction value and the classification label index.
  • Lastly, we tell the MTLRenderCommandEncoder to render the quad using the shaders, uniforms and depth stencil state.
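One detail that is easy to get wrong is that the uniforms written from Swift must match the Metal struct's memory layout byte for byte. As a hedged illustration, assume a hypothetical Metal struct holding a float4 region of interest, a float depth value, a float aspect ratio correction and an int classification index. In the Metal Shading Language a float4 is 16-byte aligned, so the struct holds 28 bytes of fields padded to 32. The Python below packs such a buffer with the standard struct module to make that padding visible; the field names and order are assumptions, not the example project's actual struct.

```python
# Hypothetical uniforms layout; field names and order are assumptions.
import struct

# Metal-side struct this buffer is meant to match:
#   struct Uniforms {
#       float4 regionOfInterest;    // x, y, width, height (normalized)
#       float  depth;               // value written into the depth buffer
#       float  aspectRatio;         // aspect ratio correction factor
#       int    classificationIndex; // e.g. 7 for "car"
#   };
UNIFORMS_FORMAT = "<4f f f i"  # little-endian, no implicit padding
UNIFORMS_ALIGNMENT = 16        # alignment of float4, the largest member


def pack_uniforms(roi, depth: float, aspect_ratio: float, class_index: int) -> bytes:
    """Pack the uniforms and pad them to the struct's 16-byte alignment."""
    data = struct.pack(UNIFORMS_FORMAT, *roi, depth, aspect_ratio, class_index)
    padding = (-len(data)) % UNIFORMS_ALIGNMENT
    return data + bytes(padding)


buffer = pack_uniforms((0.36, 0.27, 0.48, 0.36), 0.97, 1.5, 7)
print(len(buffer))  # 32: 28 bytes of fields + 4 bytes of tail padding
```

In Swift, a buffer like this would be handed to the shader with setFragmentBytes(_:length:index:), and the length passed there should be the padded size.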

As mentioned before, we’re ‘faking’ the depth because we’re not targeting devices that have a LiDAR sensor. So how do we do this?

Determining The Detected Object’s Depth

In the example we’re not using a TrueDepth camera or a LiDAR sensor, so we don’t have access to any depth data. Thankfully we do have access to the feature points detected by ARKit, which give us enough information to derive the object’s depth.

I call it ‘fake’ depth because we’ll use the same depth value for every pixel of the detected object. So in essence the object is represented as ‘flat’ in the depth buffer:

A color image on the left with the corresponding depth map on the right. Note that every pixel of the detected object has a uniform depth value.

So how do we get this value? Like so:

  • get the center point of the detected object’s rectangle (a 2D coordinate)
  • use ARSCNView’s hitTest(_:types:) -> [ARHitTestResult] method to get the closest feature point (a 3D coordinate)
  • get the distance of this feature point to the camera
  • convert this linear distance to a z-buffer (depth buffer) value

Summarized in code:
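The steps above are mostly coordinate math, sketched here in Python under a few stated assumptions. The bounding box is assumed to be in Vision's normalized coordinates (origin bottom-left), so it is flipped vertically to get a point in view coordinates (origin top-left) for the hit test. The hit test itself can't be reproduced outside ARKit, so the feature point is simply passed in. Strictly, the projection works with the distance along the camera's viewing axis, so the sketch projects the camera-to-point vector onto the camera's forward direction; near the screen center this is close to the plain distance. The last step uses the standard perspective mapping to Metal's 0...1 depth range; if the renderer uses a reversed depth buffer (see SCNView's usesReverseZ), the stored value is 1 minus that. Function names and near/far values are hypothetical.

```python
# Hypothetical helpers sketching the depth steps; not the example project's code.
from typing import Tuple

Vec3 = Tuple[float, float, float]
Rect = Tuple[float, float, float, float]  # normalized x, y, width, height


def view_point_for_box_center(box: Rect, view_width: float,
                              view_height: float) -> Tuple[float, float]:
    """Center of a bottom-left-origin normalized box, in top-left-origin view points."""
    x, y, w, h = box
    cx, cy = x + w / 2, y + h / 2
    # Flip y: Vision's origin is bottom-left, UIKit's is top-left
    return cx * view_width, (1.0 - cy) * view_height


def view_depth(point: Vec3, camera_pos: Vec3, camera_forward: Vec3) -> float:
    """Distance from the camera to the point along the unit viewing direction."""
    offset = [p - c for p, c in zip(point, camera_pos)]
    return sum(o * f for o, f in zip(offset, camera_forward))


def linear_depth_to_z_buffer(d: float, near: float, far: float,
                             reverse_z: bool = False) -> float:
    """Map a view-space distance to Metal's 0...1 depth range.

    Standard mapping: near plane -> 0, far plane -> 1. With a reversed
    depth buffer the same distance is stored as 1 minus that value.
    """
    z = (far * (d - near)) / (d * (far - near))
    return 1.0 - z if reverse_z else z


# Hypothetical inputs: a 390x844-point view, an enlarged car box, and a
# feature point 2 m in front of a camera at the origin looking down -z
# (ARKit's camera convention)
px, py = view_point_for_box_center((0.36, 0.27, 0.48, 0.36), 390.0, 844.0)
d = view_depth((0.1, -0.2, -2.0), (0.0, 0.0, 0.0), (0.0, 0.0, -1.0))
z = linear_depth_to_z_buffer(d, near=0.001, far=1000.0)
print(px, py, d, z)
```

Because the mapping is non-linear, most of the 0...1 range is spent close to the near plane, which is why the value has to be computed this way rather than by simply dividing the distance by the far plane.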

Let's move on to the vertex and fragment shaders.

Rendering Occlusion With Metal Shaders

The fragment shader is a bit more interesting. Here’s where we write to the depth buffer. This is possible by creating a struct that has a depth property with the [[depth]] attribute specified.

struct FragmentOut {
    float4 color [[color(0)]];
    float depth [[depth(any)]];
};

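This is an obscure but very useful feature of the Metal Shading Language that is unfortunately poorly documented. The argument to the depth attribute can be any, greater or less; any means the shader may write an arbitrary depth value.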

In the fragment shader we check the classification label index. If it corresponds to the object we’re targeting (in this case 7, which corresponds to ‘car’), we assign a color to the pixel and write the depth value. Otherwise we call discard_fragment() to make sure nothing is written to either the color or the depth buffer:
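The per-fragment decision is small enough to mirror in Python. This hedged sketch takes a screen position, the region of interest, a function that samples the class index, the target index, the precomputed depth and an overlay color, and returns either what the shader would output or None where the real shader calls discard_fragment(). Discarding fragments outside the region of interest and remapping the screen position into the mask's 0...1 space are assumptions about how the region of interest is used; all names are hypothetical.

```python
# Hypothetical names; a Python mirror of the shader's per-fragment decision.
from typing import Callable, Optional, Tuple

Color = Tuple[float, float, float, float]
Rect = Tuple[float, float, float, float]  # assumed top-left origin, normalized


def shade_fragment(screen_uv: Tuple[float, float], roi: Rect,
                   sample_class: Callable[[float, float], int],
                   target_class: int, depth: float,
                   color: Color) -> Optional[Tuple[Color, float]]:
    """Return (color, depth) for a kept fragment, or None where the
    Metal shader would call discard_fragment()."""
    u, v = screen_uv
    rx, ry, rw, rh = roi
    # Fragments outside the region of interest were never segmented
    if not (rx <= u <= rx + rw and ry <= v <= ry + rh):
        return None
    # Remap the screen position into the mask's own 0...1 space
    label = sample_class((u - rx) / rw, (v - ry) / rh)
    if label != target_class:
        return None
    # Kept fragment: both the color and the flat, precomputed depth are written
    return color, depth


def fake_mask(mu: float, mv: float) -> int:
    """Test mask: left half of the region is 'car' (7), right half background (0)."""
    return 7 if mu < 0.5 else 0


roi = (0.25, 0.25, 0.5, 0.5)
overlay = (1.0, 0.0, 0.0, 0.5)
print(shade_fragment((0.3, 0.5), roi, fake_mask, 7, 0.97, overlay))
print(shade_fragment((0.7, 0.5), roi, fake_mask, 7, 0.97, overlay))
```

Every kept fragment receives the same depth, which is exactly the ‘flat’ object described earlier; anything rendered by SceneKit behind that depth is hidden wherever the mask says ‘car’.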

And that wraps it up. Please keep in mind that detection and segmentation are not always perfect. Some optimizations could be put into place, but that goes beyond the scope of this post.

As mentioned earlier, the complete example can be found on GitHub:

