take(count) is an action on an RDD that returns an Array containing the first count items.
Is there a transformation that returns an RDD with the first count items? (It is OK if count is only approximate.)
The best I have come up with is:
// Iterator.take needs an Int, so round the per-partition quota up and convert
val countPerPartition = math.ceil(count / rdd.getNumPartitions.toDouble).toInt
rdd.mapPartitions(_.take(countPerPartition))
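To make this attempt concrete, here is a minimal runnable sketch of the same per-partition idea. It assumes a local SparkSession, and the names ApproxTake and approxTake as well as the sample data are made up for illustration. The result is only approximate: it can hold fewer than count items when some partitions are shorter than the quota, and slightly more when the quota is rounded up.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import scala.reflect.ClassTag

object ApproxTake {
  // Hypothetical helper: keeps at most ceil(count / numPartitions) items
  // from each partition. This is a transformation, so the result stays
  // distributed and nothing is collected to the driver.
  def approxTake[T: ClassTag](rdd: RDD[T], count: Int): RDD[T] = {
    val perPartition = math.ceil(count / rdd.getNumPartitions.toDouble).toInt
    rdd.mapPartitions(_.take(perPartition))
  }

  def main(args: Array[String]): Unit = {
    // Assumption: a local two-core session, only for trying the sketch out.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("approx-take")
      .getOrCreate()

    // Sample data spread over 4 partitions.
    val rdd = spark.sparkContext.parallelize(1 to 1000, numSlices = 4)

    // Still an RDD, so further transformations can be chained on it.
    val firstItems = approxTake(rdd, 100)
    println(firstItems.map(_ * 2).count())

    spark.stop()
  }
}
```

Note that this takes items from the head of every partition, so unlike take(count) it does not return the first count items in RDD order.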
Update:
I do not want the data to be transferred to the driver. In my case, count may be quite large, and the driver does not have enough memory to hold the result. I want the data to remain parallelized for further transformations.