How long will your machine learning project take?
How much data is needed?
I don’t know! But prototyping can give you the answer
In my work as a machine learning specialist, I have often found myself in a situation where a customer wants to know whether a particular problem can be solved using machine learning and, if so, how long it will take to develop. I always find these questions difficult to answer because, at the time of asking, I simply do not have enough information to answer them with any confidence.
ML is gathering a lot of hype these days, and there are plenty of good tutorials online that help people write their first ML program. These tutorials usually start by downloading a standard dataset, like MNIST or CIFAR10, on which you train a machine learning algorithm using one of the many frameworks available. Rather quickly, you have trained a model that gives good results. So if it takes 20 minutes to set up a perfectly good machine learning system that gives 90% accuracy, how come it can take weeks or even months to do the same on your own data?
Well, in the real world, you don’t have the luxury of a perfectly balanced dataset on which countless algorithms have already been tried. It is therefore important to realise that you need to experiment with what you have. To get started, I will suggest some questions to answer in order to arrive at a much better, data-driven estimate of the time and data needs of your machine learning project.
To take a completely made-up example, say the business wants you to create a feature that determines whether an uploaded photo is of a bird, as in the xkcd comic below. Depending on the situation, the time estimate can range from days to months.
First of all, we need to establish whether machine learning is even needed for this project. If the users are only expected to upload birds, maybe you can simply assume there is a bird in the image? Or is there some way to make the user write, or check a box stating, whether the image is of a bird or not? Then the problem is solved with some UI, and the client needs to talk to someone other than me.
Secondly, we need to establish whether a “bird classifier” already exists. Some machine learning models are freely available for use; for example, the new CoreML framework from Apple comes with five pre-trained models ready to use, all of them in the computer vision domain. You may be lucky and find that one of them works as a bird detector. In fact, I think most of these models are trained on the ImageNet dataset, which definitely has bird labels. You can run such a pre-trained model on your own data and see how well it performs. If it performs well enough, you can happily tell the customer that a bird classifier exists and that the project will only take as long as it takes to integrate the model into their production system.
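As a rough sketch of what that check might look like, here the model call is faked with a stub; in practice the (label, score) pairs would come from a real pre-trained network such as a Keras or CoreML ImageNet model, and the label set below is an illustrative assumption, not ImageNet’s actual bird classes.

```python
# Hedged sketch: deciding "bird or not" from a pre-trained classifier's output.
# BIRD_LABELS is an illustrative subset, not the real ImageNet label list.
BIRD_LABELS = {"robin", "jay", "magpie", "ostrich", "goose"}

def is_bird(top_predictions, threshold=0.3):
    """top_predictions: list of (label, score) pairs from any image classifier."""
    return any(label in BIRD_LABELS and score >= threshold
               for label, score in top_predictions)

# Stub standing in for model.predict(photo) on a real image:
fake_predictions = [("robin", 0.62), ("branch", 0.20), ("sky", 0.05)]
print(is_bird(fake_predictions))  # True
```

Running the top few predictions of an off-the-shelf model through a simple filter like this is often enough to judge whether the pre-trained route is viable at all.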
If the pre-trained models are not sufficient, then we can talk about starting a proper machine learning project. To train a model yourself, you need data the model can learn from. In this case, we need images of birds. And not just images of birds: we also need the other kinds of images users are expected to upload that are not of birds. Do you have these labeled images already? Fine, we can go ahead and start training a model. If not, you need a strategy for gathering this data, and it may take a long time to collect enough bird images.
Finally, if all the labeled data is there, you can actually start training the ML model. There are many different models to choose from, some simple and some very complex, and depending on the problem, a simple model might prove sufficient. Which leads to the final question: what accuracy is acceptable for the classifier? Is it important that the classifier is always correct when it labels an image as a “bird image” (high precision), or is it more important that it finds all possible bird images (high recall)? And what precision/recall is acceptable for the project? Is 60% good enough because it is better than no classification at all? Then a simple model might suffice. Or is only 90% or above acceptable? Then you are probably in for more serious research and more complex models (and an increased risk of overfitting).
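The difference between precision and recall can be made concrete with a few lines of Python on invented labels (1 = “bird image”, 0 = not a bird):

```python
# Made-up ground truth and predictions, purely to illustrate the two metrics.
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 1, 0, 0, 0]  # two birds missed, one false alarm

tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))  # true positives
fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))  # false positives
fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))  # false negatives

precision = tp / (tp + fp)  # of images flagged "bird", how many really are
recall = tp / (tp + fn)     # of actual bird images, how many were found
print(precision, recall)    # 0.666..., 0.5
```

Here the classifier is fairly precise (two of its three “bird” calls are correct) but has poor recall (it finds only half of the birds); which of the two matters more is a product decision, not a modeling one.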
Prototyping gives you the answers
The best way to answer all of the above questions is to start working on a prototype; developing one naturally forces you to answer them as you go along. A prototype is usually built in a short time frame, so you need to focus on the essentials. The first attempt might be a rule-based version of the bird classifier instead of a machine learning approach. Say you figure out a clever set of rules that identifies birds. If you actually manage to solve the problem with such a set of rules, you might not need machine learning after all.
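What such a rule-based first attempt looks like depends entirely on the data at hand; as a purely hypothetical illustration, suppose uploads carry user-supplied tags:

```python
# Entirely made-up, hypothetical rules: classify an upload from its
# user-supplied tags, with no machine learning involved at all.
def rule_based_is_bird(upload):
    tags = {t.lower() for t in upload.get("tags", [])}
    return bool(tags & {"bird", "birds", "birdwatching"})

print(rule_based_is_bird({"tags": ["Bird", "outdoors"]}))  # True
print(rule_based_is_bird({"tags": ["cat"]}))               # False
```

A baseline like this also doubles as a yardstick: any ML model worth the effort should at least beat the hand-written rules.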
The next step is to find a pre-trained model that solves the problem, or perhaps one that solves a different but related problem (for example, a “cat classifier”). This gives some indication of how well you will be able to solve the real problem. Also, if information about how the pre-trained model was trained is available, you can get a good idea of the amount of data needed to train the classifier and compare it with the data you currently have.
Finally, the prototype gives you some idea of the accuracy you can achieve. In my experience, you can usually get quite far, quite quickly, with simple models, but squeezing out the last couple of percentage points of accuracy can prove to be the majority of the work. Having built the prototype, you are now in a much better position to say how long it will take to build the final model, what accuracy you can reasonably expect to achieve, and whether you need to collect or label more data.
Now you might ask: how much time do you need to build a prototype? In my opinion, it is best to treat it as a sprint: set aside a fixed amount of time, say one or two weeks, formulate the above questions as backlog items and start working. By the end of the first sprint, you may have trained a Support Vector Machine on your data that gives an accuracy of 72%. If this is suitable for a production setting, you can stop prototyping and start planning how to put the model into production. Or you may decide to take another sprint and try to push the accuracy above 80%. Or perhaps you have figured out that the data is far too skewed and needs a lot more collecting and cleaning before it can work. In any case, you have learned a valuable lesson that heavily impacts the schedule and scope of the larger ML project.
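A first-sprint baseline of this kind can be only a few lines with scikit-learn. The sketch below uses the library’s bundled digits dataset as a stand-in, since real image uploads would first need labeling and feature extraction:

```python
# Sketch of a first-sprint SVM baseline, assuming scikit-learn is installed.
# The digits dataset stands in for your own (featurized, labeled) images.
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0)

clf = SVC(kernel="rbf", gamma="scale")  # sensible defaults for a baseline
clf.fit(X_train, y_train)
print(f"Held-out accuracy: {clf.score(X_test, y_test):.2f}")
```

The point is not the specific score but that a number like this, obtained on a held-out split within the first sprint, is exactly the evidence you need when deciding whether to ship, iterate, or go back to data collection.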
The questions to ask
So to sum up, when starting a machine learning project, start by building a prototype to answer these questions:
- Do you even need machine learning?
- Is there already a trained ML model out there you can use?
- If not, do you have the data to train an ML model yourself?
- What accuracy is suitable for the product and is that achievable?
Now go build your first bird classifier!